Commit 1ad0d6d: Dealing with black

yger committed Apr 26, 2024
2 parents 9ac08c3 + 57b3636

Showing 10 changed files with 104 additions and 91 deletions.
1 change: 1 addition & 0 deletions src/spikeinterface/core/generate.py
@@ -26,6 +26,7 @@ def _ensure_seed(seed):
seed = np.random.default_rng(seed=None).integers(0, 2**63)
return seed


def generate_recording(
num_channels: Optional[int] = 2,
sampling_frequency: Optional[float] = 30000.0,
39 changes: 22 additions & 17 deletions src/spikeinterface/curation/auto_merge.py
@@ -10,6 +10,7 @@

from .mergeunitssorting import MergeUnitsSorting


def get_potential_auto_merge(
sorting_analyzer,
minimum_spikes=1000,
@@ -30,7 +31,7 @@ def get_potential_auto_merge(
firing_contamination_balance=1.5,
extra_outputs=False,
steps=None,
template_metric='l1'
template_metric="l1",
):
"""
Algorithm to find and check potential merges between units.
@@ -146,7 +147,7 @@ def get_potential_auto_merge(
to_remove = num_spikes < minimum_spikes
pair_mask[to_remove, :] = False
pair_mask[:, to_remove] = False

# STEP 2 : remove contaminated auto corr
if "remove_contaminated" in steps:
contaminations, nb_violations = compute_refrac_period_violations(
@@ -170,10 +171,10 @@
)
unit_max_chan = list(unit_max_chan.values())
unit_locations = chan_loc[unit_max_chan, :]

unit_distances = scipy.spatial.distance.cdist(unit_locations, unit_locations, metric="euclidean")
pair_mask = pair_mask & (unit_distances <= maximum_distance_um)

# STEP 4 : potential auto merge by correlogram
if "correlogram" in steps:
correlograms, bins = compute_correlograms(sorting, window_ms=window_ms, bin_ms=bin_ms, method="numba")
@@ -205,14 +206,18 @@
templates_ext is not None
), "auto_merge with template_similarity requires a SortingAnalyzer with extension templates"

templates = templates_ext.get_data(outputs='Templates')
templates = templates_ext.get_data(outputs="Templates")
templates = templates.to_sparse(sorting_analyzer.sparsity)

templates_diff = compute_templates_diff(
sorting, templates, num_channels=num_channels, num_shift=num_shift, pair_mask=pair_mask,
template_metric=template_metric
sorting,
templates,
num_channels=num_channels,
num_shift=num_shift,
pair_mask=pair_mask,
template_metric=template_metric,
)

pair_mask = pair_mask & (templates_diff < template_diff_thresh)

# STEP 6 : validate the potential merges with CC increase the contamination quality metrics
@@ -393,7 +398,7 @@ def get_unit_adaptive_window(auto_corr: np.ndarray, threshold: float):
return win_size


def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric='l1'):
def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair_mask=None, template_metric="l1"):
"""
Computes normalized template differences.
@@ -448,25 +453,25 @@ def compute_templates_diff(sorting, templates, num_channels=5, num_shift=5, pair
template2 = template2[:, chan_inds]

num_samples = template1.shape[0]
if template_metric == 'l1':
if template_metric == "l1":
norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
elif template_metric == 'l2':
elif template_metric == "l2":
norm = np.sum(template1**2) + np.sum(template2**2)
elif template_metric == 'cosine':
elif template_metric == "cosine":
norm = np.linalg.norm(template1) * np.linalg.norm(template2)
all_shift_diff = []
for shift in range(-num_shift, num_shift + 1):
temp1 = template1[num_shift : num_samples - num_shift, :]
temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :]
if template_metric == 'l1':
if template_metric == "l1":
d = np.sum(np.abs(temp1 - temp2)) / norm
elif template_metric == 'l2':
elif template_metric == "l2":
d = np.linalg.norm(temp1 - temp2) / norm
elif template_metric == 'cosine':
elif template_metric == "cosine":
d = min(1, 1 - np.sum(temp1 * temp2) / norm)
all_shift_diff.append(d)
templates_diff[unit_ind1, unit_ind2] = np.min(all_shift_diff)

return templates_diff


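Note on the template_metric option added above: the three metrics ("l1", "l2", "cosine") share the same shift-and-compare loop inside compute_templates_diff. The standalone sketch below reproduces that inner loop for a single pair of templates already restricted to their common channels; the helper name normalized_template_distance is illustrative and not part of the codebase.

import numpy as np

def normalized_template_distance(template1, template2, template_metric="l1", num_shift=2):
    """Per-pair distance: compare the two templates over a range of temporal
    shifts and keep the smallest normalized difference. template1/template2
    are (num_samples, num_channels) arrays on a shared set of channels.
    Valid metrics are "l1", "l2" and "cosine"."""
    num_samples = template1.shape[0]
    if template_metric == "l1":
        norm = np.sum(np.abs(template1)) + np.sum(np.abs(template2))
    elif template_metric == "l2":
        norm = np.sum(template1**2) + np.sum(template2**2)
    elif template_metric == "cosine":
        norm = np.linalg.norm(template1) * np.linalg.norm(template2)
    all_shift_diff = []
    for shift in range(-num_shift, num_shift + 1):
        # trim the edges so both slices have the same length for every shift
        temp1 = template1[num_shift : num_samples - num_shift, :]
        temp2 = template2[num_shift + shift : num_samples - num_shift + shift, :]
        if template_metric == "l1":
            d = np.sum(np.abs(temp1 - temp2)) / norm
        elif template_metric == "l2":
            d = np.linalg.norm(temp1 - temp2) / norm
        elif template_metric == "cosine":
            d = min(1, 1 - np.sum(temp1 * temp2) / norm)
        all_shift_diff.append(d)
    return np.min(all_shift_diff)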
38 changes: 22 additions & 16 deletions src/spikeinterface/sorters/internal/spyking_circus2.py
@@ -37,7 +37,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
_default_params = {
"general": {"ms_before": 2, "ms_after": 2, "radius_um": 100},
"sparsity": {"method": "ptp", "threshold": 0.25},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype" : "bessel", "filter_order" : 2},
"filtering": {"freq_min": 150, "freq_max": 7000, "ftype": "bessel", "filter_order": 2},
"detection": {"peak_sign": "neg", "detect_threshold": 4},
"selection": {
"method": "uniform",
@@ -46,9 +46,14 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
"select_per_channel": False,
"seed": 42,
},
"drift_correction" : {"preset" : "nonrigid_fast_and_accurate"},
"merging" : {"minimum_spikes" : 10, "corr_diff_thresh" : 0.5, "template_metric" : 'cosine',
"censor_correlograms_ms" : 0.4, "num_channels" : 5},
"drift_correction": {"preset": "nonrigid_fast_and_accurate"},
"merging": {
"minimum_spikes": 10,
"corr_diff_thresh": 0.5,
"template_metric": "cosine",
"censor_correlograms_ms": 0.4,
"num_channels": 5,
},
"clustering": {"legacy": True},
"matching": {"method": "circus-omp-svd"},
"apply_preprocessing": True,
@@ -73,7 +78,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter):
True, one other clustering called circus will be used, similar to the one used in Spyking Circus 1",
"matching": "A dictionary to specify the matching engine used to recover spikes. The method default is circus-omp-svd, but other engines\
can be used",
"merging" : "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)",
"merging": "A dictionary to specify the final merging param to group cells after template matching (get_potential_auto_merge)",
"motion_correction": "A dictionary to be provided if motion correction has to be performed (dense probe only)",
"apply_preprocessing": "Boolean to specify whether circus 2 should preprocess the recording or not. If yes, then high_pass filtering + common\
median reference + zscore",
@@ -124,19 +129,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
else:
recording_f = recording
recording_f.annotate(is_filtered=True)

valid_geometry = check_probe_for_drift_correction(recording_f)
if params["drift_correction"] is not None:
if not valid_geometry:
print("Geometry of the probe does not allow 1D drift correction")
else:
print("Motion correction activated (probe geometry compatible)")
motion_folder = sorter_output_folder / "motion"
params['drift_correction'].update({'folder' : motion_folder})
recording_f = correct_motion(recording_f, **params['drift_correction'])
params["drift_correction"].update({"folder": motion_folder})
recording_f = correct_motion(recording_f, **params["drift_correction"])

## We need to whiten before the template matching step, to boost the results
recording_w = whiten(recording_f, mode='local', radius_um=radius_um, dtype="float32", regularize=True)
recording_w = whiten(recording_f, mode="local", radius_um=radius_um, dtype="float32", regularize=True)

noise_levels = get_noise_levels(recording_w, return_scaled=False)

@@ -150,9 +155,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
## Then, we are detecting peaks with a locally_exclusive method
detection_params = params["detection"].copy()
detection_params.update(job_kwargs)
detection_params["radius_um"] = detection_params.get('radius_um', 50)
detection_params["exclude_sweep_ms"] = detection_params.get('exclude_sweep_ms', 0.5)

detection_params["radius_um"] = detection_params.get("radius_um", 50)
detection_params["exclude_sweep_ms"] = detection_params.get("exclude_sweep_ms", 0.5)
detection_params["noise_levels"] = noise_levels

fs = recording_w.get_sampling_frequency()
@@ -292,14 +297,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
sorting_folder = sorter_output_folder / "sorting"
if sorting_folder.exists():
shutil.rmtree(sorting_folder)

merging_params = params["merging"].copy()

if len(merging_params) > 0:
if params['drift_correction']:
if params["drift_correction"]:
from spikeinterface.preprocessing.motion import load_motion_info

motion_info = load_motion_info(motion_folder)
merging_params['maximum_distance_um'] = max(50, 2*np.abs(motion_info['motion']).max())
merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion_info["motion"]).max())

# peak_sign = params['detection'].get('peak_sign', 'neg')
# best_amplitudes = get_template_extremum_amplitude(templates, peak_sign=peak_sign)
@@ -313,7 +319,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if curation_folder.exists():
shutil.rmtree(curation_folder)
sorting.save(folder=curation_folder)
#np.save(fitting_folder / "amplitudes", guessed_amplitudes)
# np.save(fitting_folder / "amplitudes", guessed_amplitudes)

sorting = final_cleaning_circus(recording_w, sorting, templates, **merging_params)

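Note on the "merging" parameters introduced above: when drift correction ran, the final merging step widens maximum_distance_um based on the estimated motion before the merging dictionary is handed to the cleaning step. A minimal sketch of that adjustment, using a fake motion array in place of motion_info["motion"]:

import numpy as np

# Illustrative only: mimic how the merging step derives maximum_distance_um
# after drift correction. `motion` is a fake displacement array (in microns)
# standing in for the values loaded from the motion folder.
motion = np.random.default_rng(0).normal(0, 15, size=(100, 10))
merging_params = {"minimum_spikes": 10, "corr_diff_thresh": 0.5, "template_metric": "cosine"}
# at least 50 um, larger when the estimated drift was large
merging_params["maximum_distance_um"] = max(50, 2 * np.abs(motion).max())
print(merging_params["maximum_distance_um"])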
Expand Up @@ -300,8 +300,8 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
result = self.get_result(key)
scores = result["gt_comparison"].agreement_scores

positions = result["sliced_gt_sorting"].get_property('gt_unit_locations')
#positions = self.datasets[key[1]][1].get_property("gt_unit_locations")
positions = result["sliced_gt_sorting"].get_property("gt_unit_locations")
# positions = self.datasets[key[1]][1].get_property("gt_unit_locations")
depth = positions[:, 1]

analyzer = self.get_sorting_analyzer(key)
@@ -233,7 +233,6 @@ def create_sorting_analyzer_gt(self, case_keys=None, return_scaled=True, random_
sorting_analyzer.compute("random_spikes", **random_params)
sorting_analyzer.compute("templates", **job_kwargs)
sorting_analyzer.compute("noise_levels")


def get_sorting_analyzer(self, case_key=None, dataset_key=None):
if case_key is not None:
22 changes: 12 additions & 10 deletions src/spikeinterface/sortingcomponents/clustering/circus.py
@@ -46,15 +46,16 @@ class CircusClustering:
"allow_single_cluster": True,
"core_dist_n_jobs": -1,
"cluster_selection_method": "eom",
#"cluster_selection_epsilon" : 5 ## To be optimized
# "cluster_selection_epsilon" : 5 ## To be optimized
},
"cleaning_kwargs": {},
"waveforms": {"ms_before": 2, "ms_after": 2},
"sparsity": {"method": "ptp", "threshold": 0.25},
"recursive_kwargs" : {"recursive" : True,
"recursive_depth" : 3,
"returns_split_count" : True,
},
"recursive_kwargs": {
"recursive": True,
"recursive_depth": 3,
"returns_split_count": True,
},
"radius_um": 100,
"n_svd": [5, 2],
"ms_before": 2,
@@ -143,14 +144,14 @@ def main_function(cls, recording, peaks, params):
nb_clusters = 0
for c in np.unique(peaks["channel_index"]):
mask = peaks["channel_index"] == c
sub_data = all_pc_data[mask]
sub_data = all_pc_data[mask]
sub_data = sub_data.reshape(len(sub_data), -1)

if all_pc_data.shape[1] > params["n_svd"][1]:
tsvd = TruncatedSVD(params["n_svd"][1])
else:
tsvd = TruncatedSVD(all_pc_data.shape[1])

hdbscan_data = tsvd.fit_transform(sub_data)
try:
clustering = hdbscan.hdbscan(hdbscan_data, **d["hdbscan_kwargs"])
Expand All @@ -163,7 +164,7 @@ def main_function(cls, recording, peaks, params):
peak_labels[mask] = local_labels
nb_clusters += len(np.unique(local_labels[valid_clusters]))
else:

features_folder = tmp_folder / "tsvd_features"
features_folder.mkdir(exist_ok=True)

@@ -181,11 +182,12 @@ def main_function(cls, recording, peaks, params):
sparse_mask = node1.neighbours_mask
neighbours_mask = get_channel_distances(recording) < radius_um

#np.save(features_folder / "sparse_mask.npy", sparse_mask)
# np.save(features_folder / "sparse_mask.npy", sparse_mask)
np.save(features_folder / "peaks.npy", peaks)

original_labels = peaks["channel_index"]
from spikeinterface.sortingcomponents.clustering.split import split_clusters

peak_labels, _ = split_clusters(
original_labels,
recording,
Expand All @@ -199,7 +201,7 @@ def main_function(cls, recording, peaks, params):
min_size_split=50,
clusterer_kwargs=d["hdbscan_kwargs"],
n_pca_features=params["n_svd"][1],
scale_n_pca_by_depth=True
scale_n_pca_by_depth=True,
),
**params["recursive_kwargs"],
**job_kwargs,
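Note on the clustering loop reformatted above: in the per-channel branch, peaks detected on each channel are reduced with TruncatedSVD and then clustered with HDBSCAN, with a running offset keeping cluster labels unique across channels. The simplified, self-contained sketch below runs on synthetic features; it uses the hdbscan.HDBSCAN class API and fake data, not the actual CircusClustering pipeline.

import numpy as np
from sklearn.decomposition import TruncatedSVD
import hdbscan

rng = np.random.default_rng(42)
n_peaks, n_features = 500, 10
all_pc_data = rng.normal(size=(n_peaks, n_features))   # fake per-peak SVD features
channel_index = rng.integers(0, 4, size=n_peaks)       # fake peak channels

n_svd = 2
peak_labels = np.zeros(n_peaks, dtype=np.int64)
nb_clusters = 0
for c in np.unique(channel_index):
    mask = channel_index == c
    sub_data = all_pc_data[mask].reshape(mask.sum(), -1)
    # reduce this channel's block before clustering
    hdbscan_data = TruncatedSVD(n_components=min(n_svd, sub_data.shape[1] - 1)).fit_transform(sub_data)
    local_labels = hdbscan.HDBSCAN(min_cluster_size=10, allow_single_cluster=True).fit_predict(hdbscan_data)
    valid = local_labels > -1
    # offset labels so clusters from different channels never collide; -1 stays noise
    local_labels[valid] += nb_clusters
    peak_labels[mask] = local_labels
    nb_clusters += len(np.unique(local_labels[valid]))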