Add functions used for Code Ocean pipeline #10

Open

wants to merge 85 commits into main

Commits (85)
2fd57b7
Use spikeinterface functions for DeepInterpolation training and infer…
alejoe91 Jul 11, 2023
a9072c4
Add minimal requirements
alejoe91 Jul 11, 2023
38e4128
Add separate scripts
alejoe91 Jul 14, 2023
976a73c
Create pipeline folder
alejoe91 Jul 18, 2023
ead2343
update run_spike_sorting
alejoe91 Jul 18, 2023
8594f49
Add generate_job_config function
alejoe91 Jul 18, 2023
694eaa0
Propagate JSON files to output
alejoe91 Jul 18, 2023
92e91f0
Reduce validation interval to 100ms
alejoe91 Jul 18, 2023
ece9cf3
Added collect_results script
alejoe91 Jul 18, 2023
44162b3
Change spike sorting output folders
alejoe91 Jul 18, 2023
555d9a9
Remove unused imports
alejoe91 Jul 18, 2023
67d207e
Oops
alejoe91 Jul 18, 2023
3b15043
fix sorting paths
alejoe91 Jul 18, 2023
4cd35c1
fix sorting paths 1
alejoe91 Jul 18, 2023
0a8390a
fix sorting paths 2
alejoe91 Jul 18, 2023
4eb9e45
Debug
alejoe91 Jul 18, 2023
28ed0e7
Debug1
alejoe91 Jul 18, 2023
4445737
Remove results dict
alejoe91 Jul 18, 2023
26636f4
Try to resolve base_path
alejoe91 Jul 19, 2023
60b7f21
Add session to print
alejoe91 Jul 19, 2023
cdde356
Change output folders and add super training
alejoe91 Jul 19, 2023
576ee38
Improve prints
alejoe91 Jul 19, 2023
e1e62e1
finalize super training
alejoe91 Jul 19, 2023
2001656
Fixes
alejoe91 Jul 19, 2023
ded56d1
Change relative paths
alejoe91 Jul 20, 2023
4abf984
Fix model_folder names
alejoe91 Jul 20, 2023
ff149d5
Fix inference sub folders
alejoe91 Jul 20, 2023
4c28a8f
Propagate models to results
alejoe91 Jul 20, 2023
64f4271
Set training verbose to false
alejoe91 Jul 20, 2023
e3e2232
Scale number of GPUs
alejoe91 Jul 20, 2023
4ee7afa
Adjust paths
alejoe91 Jul 20, 2023
203c3f7
Fix collect capsule
alejoe91 Jul 20, 2023
649a237
Fix collect capsule 1
alejoe91 Jul 20, 2023
f71c1dc
Fix collect capsule 2
alejoe91 Jul 20, 2023
2ca3796
Fix nb_gpus
alejoe91 Jul 20, 2023
c88957a
Specify steps per epoch
alejoe91 Jul 20, 2023
606ab18
Add remove_excess_spikes curation
alejoe91 Jul 21, 2023
c78168b
Extend to simulated data
alejoe91 Jul 21, 2023
a1e8f17
Steps per epoch 10 in debug mode
alejoe91 Jul 21, 2023
f3827c4
Move depth order later
alejoe91 Jul 21, 2023
54731b0
Move depth order later 2
alejoe91 Jul 21, 2023
a533616
Oops!
alejoe91 Jul 21, 2023
e065b00
Reintroduce debug in inference
alejoe91 Jul 21, 2023
3eadac0
Fix sorting import
alejoe91 Jul 21, 2023
01b2b0b
Optimize inference
alejoe91 Jul 22, 2023
1e496c1
Save sim to binary
alejoe91 Jul 22, 2023
1de9909
Fix zscore binary
alejoe91 Jul 22, 2023
67f43d9
Fix zscore binary 1
alejoe91 Jul 22, 2023
5d44817
Fix zscore binary 2
alejoe91 Jul 22, 2023
6b4edb7
Fix sorting eval sim
alejoe91 Jul 22, 2023
3a41f68
Fix df concatenation
alejoe91 Jul 24, 2023
60a08f1
Fix debug mode for sorting GT and add inference parallel params
alejoe91 Jul 24, 2023
c89e7ab
Add debug print
alejoe91 Jul 24, 2023
c649d27
Final cmp fix
alejoe91 Jul 24, 2023
218a57a
Add unit id column and sort columns GT
alejoe91 Jul 24, 2023
12d80a8
Add unit id column and sort columns GT 1
alejoe91 Jul 24, 2023
fab66b7
Add unit id column and sort columns GT 2
alejoe91 Jul 24, 2023
e049bb3
Debug paths
alejoe91 Jul 24, 2023
f4c8a0a
Remove resolve
alejoe91 Jul 24, 2023
b3e2257
Remove resolved and update super training
alejoe91 Jul 24, 2023
15a7489
Super-training: add probe option
alejoe91 Jul 24, 2023
3acf580
Don't max out CPU
alejoe91 Jul 24, 2023
9ce301b
Steps per epoch in super-training
alejoe91 Jul 24, 2023
6693240
Add default probeset
alejoe91 Jul 24, 2023
1d57918
Limit n_jobs
alejoe91 Jul 24, 2023
f38240d
Set n_jobs with params
alejoe91 Jul 25, 2023
bf26914
Handle sorting errors in run_sorting function
alejoe91 Jul 31, 2023
c2672ae
Oops
alejoe91 Jul 31, 2023
122707b
Protect against small mismatches in sampling frequency
alejoe91 Aug 2, 2023
96abe6a
Correctly save session-level results
alejoe91 Aug 29, 2023
fec2c53
Merge branch 'use-si-functions' of github.com:Jad-Selman/ephys-deepin…
alejoe91 Aug 29, 2023
8067f75
A round of black
alejoe91 Aug 29, 2023
e4d473d
Extend training
alejoe91 Sep 5, 2023
83008fe
Extend unit level and matched unit level results
alejoe91 Sep 15, 2023
05b74f7
wrong argument
alejoe91 Sep 15, 2023
339cc4a
Fixes
alejoe91 Sep 15, 2023
db18cd9
Use scratch folder for waveforms and n_jobs=1 for QM
alejoe91 Sep 15, 2023
8950ad7
Fix data loading in spike sorting pipeline
alejoe91 Sep 15, 2023
9fbb1d9
pre-compute sparsity
alejoe91 Sep 18, 2023
26c50b8
Add filter_option in JSON and skip NN metrics
alejoe91 Sep 18, 2023
313211a
Cleanup
alejoe91 Sep 18, 2023
84dc31e
Save sparse waveforms in results
alejoe91 Sep 18, 2023
8196427
filter_options -> filter_option
alejoe91 Sep 18, 2023
53754d2
Remove correct folder
alejoe91 Sep 19, 2023
2a4f63b
Add metrics to matched-units csv
alejoe91 Sep 19, 2023
28 changes: 0 additions & 28 deletions generate_image_inference.py

This file was deleted.

105 changes: 105 additions & 0 deletions pipeline/run_collect_results.py
@@ -0,0 +1,105 @@
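# Collect-results capsule: concatenates the per-session CSV outputs
# (sessions, units, matched-units) produced upstream and copies the
# sorting and model folders into the results folder.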
import warnings

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)


#### IMPORTS #######
import shutil


from pathlib import Path
import pandas as pd


base_path = Path("..")

data_folder = base_path / "data"
scratch_folder = base_path / "scratch"
results_folder = base_path / "results"


if __name__ == "__main__":
# list all data entries
print("Data folder content:")
for p in data_folder.iterdir():
print(f"\t{p.name}")

# concatenate dataframes
df_session = None
df_units = None
df_matched_units = None

# upstream outputs may sit either directly in data/ or inside one
# subfolder per capsule run; handle both layouts
probe_sortings_folders = [p for p in data_folder.iterdir() if p.name.startswith("sorting_") and p.is_dir()]

if len(probe_sortings_folders) > 0:
data_models_folder = data_folder
data_sortings_folder = data_folder
else:
data_model_subfolders = []
for p in data_folder.iterdir():
if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0:
data_model_subfolders.append(p)
data_models_folder = data_model_subfolders[0]

data_sorting_subfolders = []
for p in data_folder.iterdir():
if p.is_dir() and len([pp for pp in p.iterdir() if "sorting_" in pp.name and pp.is_dir()]) > 0:
data_sorting_subfolders.append(p)
data_sortings_folder = data_sorting_subfolders[0]

session_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("sessions.csv")]
unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("units.csv") and "matched" not in p.name]
matched_unit_csvs = [p for p in data_sortings_folder.iterdir() if p.name.endswith("matched-units.csv")]

for session_csv in session_csvs:
if df_session is None:
df_session = pd.read_csv(session_csv)
else:
df_session = pd.concat([df_session, pd.read_csv(session_csv)])

for unit_csv in unit_csvs:
if df_units is None:
df_units = pd.read_csv(unit_csv)
else:
df_units = pd.concat([df_units, pd.read_csv(unit_csv)])

for matched_unit_csv in matched_unit_csvs:
if df_matched_units is None:
df_matched_units = pd.read_csv(matched_unit_csv)
else:
df_matched_units = pd.concat([df_matched_units, pd.read_csv(matched_unit_csv)])

# save concatenated dataframes (skip any that were not found)
if df_session is not None:
df_session.to_csv(results_folder / "sessions.csv", index=False)
if df_units is not None:
df_units.to_csv(results_folder / "units.csv", index=False)
if df_matched_units is not None:
df_matched_units.to_csv(results_folder / "matched-units.csv", index=False)

# copy sortings to results folder
sortings_folders = [p for p in data_sortings_folder.iterdir() if "sorting_" in p.name and p.is_dir()]
sortings_output_base_folder = results_folder / "sortings"
sortings_output_base_folder.mkdir(exist_ok=True)

for sorting_folder in sortings_folders:
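# folder names follow sorting_{dataset}_{session}_{filter_option};
# session names can themselves contain underscores, hence the join
# over the middle tokens below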
sorting_folder_split = sorting_folder.name.split("_")
dataset_name = sorting_folder_split[1]
session_name = "_".join(sorting_folder_split[2:-1])
filter_option = sorting_folder_split[-1]
sorting_output_folder = sortings_output_base_folder / dataset_name / session_name / filter_option
sorting_output_folder.mkdir(exist_ok=True, parents=True)
for sorting_subfolder in sorting_folder.iterdir():
shutil.copytree(sorting_subfolder, sorting_output_folder / sorting_subfolder.name)

# copy models to results folder
models_folders = [p for p in data_models_folder.iterdir() if "model_" in p.name and p.is_dir()]
models_output_base_folder = results_folder / "models"
models_output_base_folder.mkdir(exist_ok=True)

for model_folder in models_folders:
model_folder_split = model_folder.name.split("_")
dataset_name = model_folder_split[1]
session_name = "_".join(model_folder_split[2:-1])
filter_option = model_folder_split[-1]
model_output_folder = models_output_base_folder / dataset_name / session_name / filter_option
model_output_folder.parent.mkdir(exist_ok=True, parents=True)
shutil.copytree(model_folder, model_output_folder)
233 changes: 233 additions & 0 deletions pipeline/run_inference.py
@@ -0,0 +1,233 @@
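# Inference capsule: applies a trained DeepInterpolation model to each
# session and filter option, saves the denoised recording, and dumps the
# processed recordings to JSON for downstream capsules.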
import warnings

warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)


#### IMPORTS #######
import os
import sys
import shutil
import json
import numpy as np
from pathlib import Path
import pandas as pd
import time


# SpikeInterface
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.sorters as ss
import spikeinterface.postprocessing as spost
import spikeinterface.comparison as sc
import spikeinterface.qualitymetrics as sqm

# Tensorflow
import tensorflow as tf


# NOTE: to reliably cap OpenBLAS threads this should be set before numpy is imported
os.environ["OPENBLAS_NUM_THREADS"] = "1"


base_path = Path("..")

##### DEFINE DATASETS AND FOLDERS #######
from sessions import all_sessions_exp, all_sessions_sim

n_jobs = 24

job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s")

data_folder = base_path / "data"
scratch_folder = base_path / "scratch"
results_folder = base_path / "results"


if (data_folder / "ephys-compression-benchmark").is_dir():
DATASET_FOLDER = data_folder / "ephys-compression-benchmark"
all_sessions = all_sessions_exp
data_type = "exp"
elif (data_folder / "MEArec-NP-recordings").is_dir():
DATASET_FOLDER = data_folder / "MEArec-NP-recordings"
all_sessions = all_sessions_sim
data_type = "sim"
else:
raise Exception("Could not find dataset folder")


DEBUG = False
NUM_DEBUG_SESSIONS = 2
DEBUG_DURATION = 20

##### DEFINE PARAMS #####
OVERWRITE = False
USE_GPU = True


# Filter options to benchmark
FILTER_OPTIONS = ["bp", "hp"]  # "hp", "bp", "no"

# DI params
pre_frame = 30
post_frame = 30
pre_post_omission = 1
desired_shape = (192, 2)
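# pre_frame/post_frame set the temporal context fed to the network around each
# target frame, and pre_post_omission drops the frames immediately adjacent to
# the target (the standard DeepInterpolation windowing). desired_shape = (192, 2)
# presumably folds the 384 Neuropixels channels into a two-column layout.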
# play around with these
inference_n_jobs = 24
inference_chunk_duration = "1s"
inference_predict_workers = 1
inference_memory_gpu = 2000 # MB

di_kwargs = dict(
pre_frame=pre_frame,
post_frame=post_frame,
pre_post_omission=pre_post_omission,
desired_shape=desired_shape,
)

if __name__ == "__main__":
if len(sys.argv) == 5:
if sys.argv[1] == "true":
DEBUG = True
OVERWRITE = True
else:
DEBUG = False
OVERWRITE = False
n_jobs = int(sys.argv[2])
inference_n_jobs = int(sys.argv[3])
inference_predict_workers = int(sys.argv[4])
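# Example invocation (hypothetical values):
#   python run_inference.py true 16 8 2
# argv: debug+overwrite flag ("true"/"false"), n_jobs, inference_n_jobs,
# inference_predict_workers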

session_dict = all_sessions
filter_options = FILTER_OPTIONS

json_files = [p for p in data_folder.iterdir() if p.name.endswith(".json")]
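# A single JSON config selects one session to run. Expected keys (hypothetical
# values): {"probe": "NP2", "session": "dataset_name/session_name"}, plus an
# optional "filter_option" (e.g. "bp")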
if len(json_files) == 1:
print(f"Found {len(json_files)} JSON config")
session_dict = {}
# each json file contains a session to run
json_file = json_files[0]
with open(json_file, "r") as f:
config = json.load(f)
probe = config["probe"]
if probe not in session_dict:
session_dict[probe] = []
session = config["session"]
assert (
session in all_sessions[probe]
), f"{session} is not a valid session. Valid sessions for {probe} are:\n{all_sessions[probe]}"
session_dict[probe].append(session)
if "filter_option" in config:
filter_options = [config["filter_option"]]
else:
filter_options = FILTER_OPTIONS
elif len(json_files) > 1:
print("Only 1 JSON config file allowed, using default sessions")

print(f"Sessions:\n{session_dict}")
print(f"Filter options:\n{filter_options}")
si.set_global_job_kwargs(**job_kwargs)

print(f"Tensorflow GPU status: {tf.config.list_physical_devices('GPU')}")

#### START ####
probe_models_folders = [p for p in data_folder.iterdir() if "model_" in p.name and p.is_dir()]

if len(probe_models_folders) > 0:
data_model_folder = data_folder
else:
data_model_subfolders = []
for p in data_folder.iterdir():
if p.is_dir() and len([pp for pp in p.iterdir() if "model_" in pp.name and pp.is_dir()]) > 0:
data_model_subfolders.append(p)
data_model_folder = data_model_subfolders[0]

for probe, sessions in session_dict.items():
print(f"Dataset {probe}")
for session in sessions:
print(f"\nAnalyzing session {session}\n")
dataset_name, session_name = session.split("/")

if data_type == "exp":
recording = si.load_extractor(DATASET_FOLDER / session)
else:
recording, _ = se.read_mearec(DATASET_FOLDER / session)
session_name = session_name.split(".")[0]

if DEBUG:
recording = recording.frame_slice(
start_frame=0,
end_frame=int(DEBUG_DURATION * recording.sampling_frequency),
)
print(f"\t{recording}")

for filter_option in filter_options:
print(f"\tFilter option: {filter_option}")
recording_name = f"{dataset_name}_{session_name}_{filter_option}"

# apply filter and zscore
if filter_option == "hp":
recording_processed = spre.highpass_filter(recording)
elif filter_option == "bp":
recording_processed = spre.bandpass_filter(recording)
else:
recording_processed = recording

if data_type == "sim":
recording_processed = spre.depth_order(recording_processed)

recording_zscore = spre.zscore(recording_processed)
# Saving the z-scored recording to a binary file avoids recomputing the
# preprocessing chain on every chunk during inference, which speeds things up a lot
recording_zscore_bin = recording_zscore.save(
folder=scratch_folder / f"recording_zscored_{recording_name}"
)

# locate the trained model for this recording
model_folder = data_model_folder / f"model_{recording_name}"
model_path = [p for p in model_folder.iterdir() if p.name.endswith("model.h5")][0]
# full inference
output_folder = results_folder / f"deepinterpolated_{recording_name}"
if OVERWRITE and output_folder.is_dir():
shutil.rmtree(output_folder)

if not output_folder.is_dir():
t_start_inference = time.perf_counter()
output_folder.parent.mkdir(exist_ok=True, parents=True)
recording_di = spre.deepinterpolate(
recording_zscore_bin,
model_path=model_path,
pre_frame=pre_frame,
post_frame=post_frame,
pre_post_omission=pre_post_omission,
memory_gpu=inference_memory_gpu,
predict_workers=inference_predict_workers,
use_gpu=USE_GPU,
)
recording_di = recording_di.save(
folder=output_folder,
n_jobs=inference_n_jobs,
chunk_duration=inference_chunk_duration,
)
t_stop_inference = time.perf_counter()
elapsed_time_inference = np.round(t_stop_inference - t_start_inference, 2)
print(f"\t\tElapsed time INFERENCE: {elapsed_time_inference}s")
else:
print("\t\tLoading existing folder")
recording_di = si.load_extractor(output_folder)
# apply inverse z-scoring
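# zscore applies z = gain * x + offset, so the inverse transform is
# x = z * (1 / gain) - offset * (1 / gain), computed below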
inverse_gains = 1 / recording_zscore.gain
inverse_offset = -recording_zscore.offset * inverse_gains
recording_di = spre.scale(recording_di, gain=inverse_gains, offset=inverse_offset, dtype="float")

# save processed json
processed_folder = results_folder / f"processed_{recording_name}"
processed_folder.mkdir(exist_ok=True, parents=True)
recording_processed.dump_to_json(processed_folder / "processed.json", relative_to=results_folder)
recording_di.dump_to_json(processed_folder / "deepinterpolated.json", relative_to=results_folder)

for json_file in json_files:
shutil.copy(json_file, results_folder)