Skip to content

Commit

Permalink
Use existing sparsity for unit location + add location with max channel
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelgarcia committed Oct 15, 2024
1 parent 0ae32e7 commit 49c7a92
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 4 deletions.
67 changes: 63 additions & 4 deletions src/spikeinterface/postprocessing/localization_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,12 @@ def compute_monopolar_triangulation(
assert feature in ["ptp", "energy", "peak_voltage"], f"{feature} is not a valid feature"

contact_locations = sorting_analyzer_or_templates.get_channel_locations()

if sorting_analyzer_or_templates.sparsity is None:
sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
else:
sparsity = sorting_analyzer_or_templates.sparsity

sparsity = compute_sparsity(sorting_analyzer_or_templates, method="radius", radius_um=radius_um)
templates = get_dense_templates_array(
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
)
Expand Down Expand Up @@ -157,9 +161,13 @@ def compute_center_of_mass(

assert feature in ["ptp", "mean", "energy", "peak_voltage"], f"{feature} is not a valid feature"

sparsity = compute_sparsity(
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
)
if sorting_analyzer_or_templates.sparsity is None:
sparsity = compute_sparsity(
sorting_analyzer_or_templates, peak_sign=peak_sign, method="radius", radius_um=radius_um
)
else:
sparsity = sorting_analyzer_or_templates.sparsity

templates = get_dense_templates_array(
sorting_analyzer_or_templates, return_scaled=get_return_scaled(sorting_analyzer_or_templates)
)
Expand Down Expand Up @@ -650,8 +658,59 @@ def get_convolution_weights(
enforce_decrease_shells = numba.jit(enforce_decrease_shells_data, nopython=True)



def compute_location_max_channel(
templates_or_sorting_analyzer: SortingAnalyzer | Templates,
unit_ids=None,
peak_sign: "neg" | "pos" | "both" = "neg",
mode: "extremum" | "at_index" | "peak_to_peak" = "extremum",
) -> np.ndarray:
"""
Localize a unit using max channel.
This use inetrnally get_template_extremum_channel()
Parameters
----------
templates_or_sorting_analyzer : SortingAnalyzer | Templates
A SortingAnalyzer or Templates object
unit_ids: str | int | None
A list of unit_id to restrict the computation
peak_sign : "neg" | "pos" | "both"
Sign of the template to find extremum channels
mode : "extremum" | "at_index" | "peak_to_peak", default: "at_index"
Where the amplitude is computed
* "extremum" : take the peak value (max or min depending on `peak_sign`)
* "at_index" : take value at `nbefore` index
* "peak_to_peak" : take the peak-to-peak amplitude
Returns
-------
unit_location: np.ndarray
2d
"""
extremum_channels_index = get_template_extremum_channel(
templates_or_sorting_analyzer,
peak_sign=peak_sign,
mode=mode,
outputs="index"
)
contact_locations = templates_or_sorting_analyzer.get_channel_locations()
if unit_ids is None:
unit_ids = templates_or_sorting_analyzer.unit_ids
else:
unit_ids = np.asarray(unit_ids)
unit_location = np.zeros((unit_ids.size, 2), dtype="float32")
for i, unit_id in enumerate(unit_ids):
unit_location[i, :] = contact_locations[extremum_channels_index[unit_id]]

return unit_location


_unit_location_methods = {
"center_of_mass": compute_center_of_mass,
"grid_convolution": compute_grid_convolution,
"monopolar_triangulation": compute_monopolar_triangulation,
"max_channel": compute_location_max_channel,
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class TestUnitLocationsExtension(AnalyzerExtensionCommonTestSuite):
dict(method="grid_convolution", radius_um=150, weight_method={"mode": "gaussian_2d"}),
dict(method="monopolar_triangulation", radius_um=150),
dict(method="monopolar_triangulation", radius_um=150, optimizer="minimize_with_log_penality"),
dict(method="max_channel"),
],
)
def test_extension(self, params):
Expand Down

0 comments on commit 49c7a92

Please sign in to comment.