Merge pull request #2014 from alejoe91/speed-up-amp-scalings
Extend common postprocessing tests to spikes at borders
samuelgarcia authored Sep 29, 2023
2 parents 89affa7 + 537ebbb commit 225d9d1
Showing 7 changed files with 160 additions and 54 deletions.
28 changes: 28 additions & 0 deletions src/spikeinterface/core/generate.py
@@ -123,6 +123,9 @@ def generate_sorting(
     firing_rates=3.0,
     empty_units=None,
     refractory_period_ms=3.0,  # in ms
+    add_spikes_on_borders=False,
+    num_spikes_per_border=3,
+    border_size_samples=20,
     seed=None,
 ):
     """
@@ -142,6 +145,12 @@
         List of units that will have no spikes. (used for testing mainly).
     refractory_period_ms : float, default: 3.0
         The refractory period in ms
+    add_spikes_on_borders : bool, default: False
+        If True, spikes will be added close to the borders of the segments.
+    num_spikes_per_border : int, default: 3
+        The number of spikes to add close to the borders of the segments.
+    border_size_samples : int, default: 20
+        The size of the border in samples to add border spikes.
     seed : int, default: None
         The random seed
@@ -151,11 +160,13 @@
         The sorting object
     """
     seed = _ensure_seed(seed)
+    rng = np.random.default_rng(seed)
     num_segments = len(durations)
     unit_ids = np.arange(num_units)

     spikes = []
     for segment_index in range(num_segments):
+        num_samples = int(sampling_frequency * durations[segment_index])
         times, labels = synthesize_random_firings(
             num_units=num_units,
             sampling_frequency=sampling_frequency,
@@ -175,7 +186,23 @@
         spikes_in_seg["unit_index"] = labels
         spikes_in_seg["segment_index"] = segment_index
         spikes.append(spikes_in_seg)

+        if add_spikes_on_borders:
+            spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype)
+            spikes_on_borders["segment_index"] = segment_index
+            spikes_on_borders["unit_index"] = rng.choice(num_units, size=2 * num_spikes_per_border, replace=True)
+            # at start
+            spikes_on_borders["sample_index"][:num_spikes_per_border] = rng.integers(
+                0, border_size_samples, num_spikes_per_border
+            )
+            # at end
+            spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers(
+                num_samples - border_size_samples, num_samples, num_spikes_per_border
+            )
+            spikes.append(spikes_on_borders)
+
     spikes = np.concatenate(spikes)
+    spikes = spikes[np.lexsort((spikes["sample_index"], spikes["segment_index"]))]

     sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

@@ -596,6 +623,7 @@ def __init__(
         dtype = np.dtype(dtype).name  # Cast to string for serialization
         if dtype not in ("float32", "float64"):
             raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
+        assert strategy in ("tile_pregenerated", "on_the_fly"), "'strategy' must be 'tile_pregenerated' or 'on_the_fly'"

         BaseRecording.__init__(self, sampling_frequency=sampling_frequency, channel_ids=channel_ids, dtype=dtype)

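To make the new options concrete, here is a minimal usage sketch (not part of the commit; it assumes generate_sorting is importable from spikeinterface.core, and the parameter names match the diff above):

import numpy as np
from spikeinterface.core import generate_sorting

# two 10 s segments, with extra spikes forced near the segment edges
sorting = generate_sorting(
    num_units=5,
    sampling_frequency=30000.0,
    durations=[10.0, 10.0],
    add_spikes_on_borders=True,   # enable the new behavior
    num_spikes_per_border=3,      # 3 spikes near the start + 3 near the end of each segment
    border_size_samples=20,       # "near" means within the first/last 20 samples
    seed=42,
)

# the spike vector is globally sorted by (segment_index, sample_index),
# thanks to the np.lexsort call added above
spikes = sorting.to_spike_vector()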
36 changes: 33 additions & 3 deletions src/spikeinterface/core/tests/test_generate.py
@@ -26,15 +26,44 @@


 def test_generate_recording():
-    # TODO even this is extenssivly tested in all other function
+    # TODO even this is extensively tested in all other functions
     pass


 def test_generate_sorting():
-    # TODO even this is extenssivly tested in all other function
+    # TODO even this is extensively tested in all other functions
     pass


+def test_generate_sorting_with_spikes_on_borders():
+    num_spikes_on_borders = 10
+    border_size_samples = 10
+    segment_duration = 10
+    for nseg in [1, 2, 3]:
+        sorting = generate_sorting(
+            durations=[segment_duration] * nseg,
+            sampling_frequency=30000,
+            num_units=10,
+            add_spikes_on_borders=True,
+            num_spikes_per_border=num_spikes_on_borders,
+            border_size_samples=border_size_samples,
+        )
+        # check that segments are correctly sorted
+        all_spikes = sorting.to_spike_vector()
+        np.testing.assert_array_equal(all_spikes["segment_index"], np.sort(all_spikes["segment_index"]))
+
+        spikes = sorting.to_spike_vector(concatenated=False)
+        # at least num_border spikes at borders for all segments
+        for spikes_in_segment in spikes:
+            # check that sample indices are correctly sorted within segments
+            np.testing.assert_array_equal(spikes_in_segment["sample_index"], np.sort(spikes_in_segment["sample_index"]))
+            num_samples = int(segment_duration * 30000)
+            assert np.sum(spikes_in_segment["sample_index"] < border_size_samples) >= num_spikes_on_borders
+            assert (
+                np.sum(spikes_in_segment["sample_index"] >= num_samples - border_size_samples) >= num_spikes_on_borders
+            )
+
+
 def measure_memory_allocation(measure_in_process: bool = True) -> float:
     """
     A local utility to measure memory allocation at a specific point in time.
@@ -399,7 +428,7 @@ def test_generate_ground_truth_recording():
 if __name__ == "__main__":
     strategy = "tile_pregenerated"
     # strategy = "on_the_fly"
-    test_noise_generator_memory()
+    # test_noise_generator_memory()
     # test_noise_generator_under_giga()
     # test_noise_generator_correct_shape(strategy)
     # test_noise_generator_consistency_across_calls(strategy, 0, 5)
@@ -410,3 +439,4 @@ def test_generate_ground_truth_recording():
     # test_generate_templates()
     # test_inject_templates()
     # test_generate_ground_truth_recording()
+    test_generate_sorting_with_spikes_on_borders()
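A side note on the ordering the new test checks: np.lexsort sorts by its last key first, so passing (sample_index, segment_index) groups spikes by segment and sorts by sample within each segment. A tiny self-contained demonstration (toy values, not from the repo; the dtype fields mirror minimum_spike_dtype as used in the diff):

import numpy as np

spike_dtype = [("sample_index", "int64"), ("unit_index", "int64"), ("segment_index", "int64")]
spikes = np.array([(25, 0, 1), (3, 1, 0), (10, 0, 0)], dtype=spike_dtype)

# the last key (segment_index) is the primary sort key
order = np.lexsort((spikes["sample_index"], spikes["segment_index"]))
spikes = spikes[order]

print(spikes["segment_index"])  # [0 0 1]    -> segments grouped and in order
print(spikes["sample_index"])   # [ 3 10 25] -> sorted within each segment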
70 changes: 39 additions & 31 deletions src/spikeinterface/postprocessing/amplitude_scalings.py
@@ -16,6 +16,7 @@ class AmplitudeScalingsCalculator(BaseWaveformExtractorExtension):
     """

     extension_name = "amplitude_scalings"
+    handle_sparsity = True

     def __init__(self, waveform_extractor):
         BaseWaveformExtractorExtension.__init__(self, waveform_extractor)
@@ -68,7 +69,6 @@ def _run(self, **job_kwargs):
         delta_collision_samples = int(delta_collision_ms / 1000 * we.sampling_frequency)

         return_scaled = we._params["return_scaled"]
-        unit_ids = we.unit_ids

         if ms_before is not None:
             assert (
@@ -82,18 +82,22 @@
         cut_out_before = int(ms_before / 1000 * we.sampling_frequency) if ms_before is not None else nbefore
         cut_out_after = int(ms_after / 1000 * we.sampling_frequency) if ms_after is not None else nafter

-        if we.is_sparse():
+        if we.is_sparse() and self._params["sparsity"] is None:
             sparsity = we.sparsity
-        elif self._params["sparsity"] is not None:
+        elif we.is_sparse() and self._params["sparsity"] is not None:
             sparsity = self._params["sparsity"]
+            # assert provided sparsity is sparser than the one in the waveform extractor
+            waveform_sparsity = we.sparsity
+            assert np.all(
+                np.sum(waveform_sparsity.mask, 1) - np.sum(sparsity.mask, 1) > 0
+            ), "The provided sparsity needs to be sparser than the one in the waveform extractor!"
+        elif not we.is_sparse() and self._params["sparsity"] is not None:
+            sparsity = self._params["sparsity"]
         else:
             if self._params["max_dense_channels"] is not None:
                 assert recording.get_num_channels() <= self._params["max_dense_channels"], ""
             sparsity = ChannelSparsity.create_dense(we)
-        sparsity_inds = sparsity.unit_id_to_channel_indices
-
-        # easier to use in chunk function as spikes use unit_index instead o id
-        unit_inds_to_channel_indices = {unit_ind: sparsity_inds[unit_id] for unit_ind, unit_id in enumerate(unit_ids)}
+        sparsity_mask = sparsity.mask
         all_templates = we.get_all_templates()

         # precompute segment slice
@@ -112,7 +116,7 @@ def _run(self, **job_kwargs):
             self.spikes,
             all_templates,
             segment_slices,
-            unit_inds_to_channel_indices,
+            sparsity_mask,
             nbefore,
             nafter,
             cut_out_before,
@@ -261,7 +265,7 @@ def _init_worker_amplitude_scalings(
     spikes,
     all_templates,
     segment_slices,
-    unit_inds_to_channel_indices,
+    sparsity_mask,
     nbefore,
     nafter,
     cut_out_before,
@@ -281,7 +285,7 @@
     worker_ctx["cut_out_before"] = cut_out_before
     worker_ctx["cut_out_after"] = cut_out_after
     worker_ctx["return_scaled"] = return_scaled
-    worker_ctx["unit_inds_to_channel_indices"] = unit_inds_to_channel_indices
+    worker_ctx["sparsity_mask"] = sparsity_mask
     worker_ctx["handle_collisions"] = handle_collisions
     worker_ctx["delta_collision_samples"] = delta_collision_samples

@@ -305,7 +309,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
     recording = worker_ctx["recording"]
     all_templates = worker_ctx["all_templates"]
     segment_slices = worker_ctx["segment_slices"]
-    unit_inds_to_channel_indices = worker_ctx["unit_inds_to_channel_indices"]
+    sparsity_mask = worker_ctx["sparsity_mask"]
     nbefore = worker_ctx["nbefore"]
     cut_out_before = worker_ctx["cut_out_before"]
     cut_out_after = worker_ctx["cut_out_after"]
@@ -338,7 +342,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
             )
             local_spikes_w_margin = spikes_in_segment[i0_margin:i1_margin]
             collisions_local = find_collisions(
-                local_spikes, local_spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices
+                local_spikes, local_spikes_w_margin, delta_collision_samples, sparsity_mask
             )
         else:
             collisions_local = {}
@@ -353,7 +357,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
                 continue
             unit_index = spike["unit_index"]
             sample_index = spike["sample_index"]
-            sparse_indices = unit_inds_to_channel_indices[unit_index]
+            (sparse_indices,) = np.nonzero(sparsity_mask[unit_index])
             template = all_templates[unit_index][:, sparse_indices]
             template = template[nbefore - cut_out_before : nbefore + cut_out_after]
             sample_centered = sample_index - start_frame
@@ -364,7 +368,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
                 template = template[cut_out_before - sample_index :]
             elif sample_index + cut_out_after > end_frame + right:
                 local_waveform = traces_with_margin[cut_out_start:, sparse_indices]
-                template = template[: -(sample_index + cut_out_after - end_frame)]
+                template = template[: -(sample_index + cut_out_after - (end_frame + right))]
             else:
                 local_waveform = traces_with_margin[cut_out_start:cut_out_end, sparse_indices]
             assert template.shape == local_waveform.shape
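The elif branch above is the right-border bugfix that the new border tests exercise: when a spike's cut-out runs past the end of the margined chunk, the template must be truncated by the overhang relative to end_frame + right (the chunk end including the right margin), not end_frame alone. A toy numeric check of that arithmetic (invented values, a sketch rather than repo code):

import numpy as np

end_frame, right = 100, 10             # chunk end and right margin, in samples
cut_out_before, cut_out_after = 5, 8
sample_index = 105                     # spike inside the right margin

template = np.arange(cut_out_before + cut_out_after)  # stand-in template cut, 13 samples
overhang = sample_index + cut_out_after - (end_frame + right)  # 3 samples past the traces
template = template[:-overhang]

# the local waveform available in traces_with_margin has the same, shorter length
local_len = (end_frame + right) - (sample_index - cut_out_before)
assert len(template) == local_len == 10  # with the old end_frame-only formula the shapes diverge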
@@ -392,7 +396,7 @@ def _amplitude_scalings_chunk(segment_index, start_frame, end_frame, worker_ctx)
                    right,
                    nbefore,
                    all_templates,
-                   unit_inds_to_channel_indices,
+                   sparsity_mask,
                    cut_out_before,
                    cut_out_after,
                )
@@ -409,14 +413,14 @@


 ### Collision handling ###
-def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j):
+def _are_unit_indices_overlapping(sparsity_mask, i, j):
     """
     Returns True if the unit indices i and j are overlapping, False otherwise

     Parameters
     ----------
-    unit_inds_to_channel_indices: dict
-        A dictionary mapping unit indices to channel indices
+    sparsity_mask: boolean mask
+        The sparsity mask
     i: int
         The first unit index
     j: int
@@ -427,13 +431,13 @@ def _are_unit_indices_overlapping(unit_inds_to_channel_indices, i, j):
     bool
         True if the unit indices i and j are overlapping, False otherwise
     """
-    if len(np.intersect1d(unit_inds_to_channel_indices[i], unit_inds_to_channel_indices[j])) > 0:
+    if np.any(sparsity_mask[i] & sparsity_mask[j]):
         return True
     else:
         return False


-def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices):
+def find_collisions(spikes, spikes_w_margin, delta_collision_samples, sparsity_mask):
     """
     Finds the collisions between spikes.
@@ -445,8 +449,8 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices):
         An array of spikes within the added margin
     delta_collision_samples: int
         The maximum number of samples between two spikes to consider them as overlapping
-    unit_inds_to_channel_indices: dict
-        A dictionary mapping unit indices to channel indices
+    sparsity_mask: boolean mask
+        The sparsity mask

     Returns
     -------
@@ -476,7 +480,7 @@ def find_collisions(spikes, spikes_w_margin, delta_collision_samples, unit_inds_to_channel_indices):
        # find the overlapping spikes in space as well
        for possible_overlapping_spike_index in possible_overlapping_spike_indices:
            if _are_unit_indices_overlapping(
-               unit_inds_to_channel_indices,
+               sparsity_mask,
                spike["unit_index"],
                spikes_w_margin[possible_overlapping_spike_index]["unit_index"],
            ):
@@ -497,7 +501,7 @@ def fit_collision(
     right,
     nbefore,
     all_templates,
-    unit_inds_to_channel_indices,
+    sparsity_mask,
     cut_out_before,
     cut_out_after,
 ):
@@ -524,8 +528,8 @@
         The number of samples before the spike to consider for the fit.
     all_templates: np.ndarray
         A numpy array of shape (n_units, n_samples, n_channels) containing the templates.
-    unit_inds_to_channel_indices: dict
-        A dictionary mapping unit indices to channel indices.
+    sparsity_mask: boolean mask
+        The sparsity mask
     cut_out_before: int
         The number of samples to cut out before the spike.
     cut_out_after: int
@@ -543,14 +547,16 @@
     sample_last_centered = np.max(collision["sample_index"]) - (start_frame - left)

     # construct sparsity as union between units' sparsity
-    sparse_indices = np.array([], dtype="int")
+    common_sparse_mask = np.zeros(sparsity_mask.shape[1], dtype="int")
     for spike in collision:
-        sparse_indices_i = unit_inds_to_channel_indices[spike["unit_index"]]
-        sparse_indices = np.union1d(sparse_indices, sparse_indices_i)
+        mask_i = sparsity_mask[spike["unit_index"]]
+        common_sparse_mask = np.logical_or(common_sparse_mask, mask_i)
+    (sparse_indices,) = np.nonzero(common_sparse_mask)

     local_waveform_start = max(0, sample_first_centered - cut_out_before)
     local_waveform_end = min(traces_with_margin.shape[0], sample_last_centered + cut_out_after)
     local_waveform = traces_with_margin[local_waveform_start:local_waveform_end, sparse_indices]
+    num_samples_local_waveform = local_waveform.shape[0]

     y = local_waveform.T.flatten()
     X = np.zeros((len(y), len(collision)))
@@ -563,8 +569,10 @@
         # deal with borders
         if sample_centered - cut_out_before < 0:
             full_template[: sample_centered + cut_out_after] = template_cut[cut_out_before - sample_centered :]
-        elif sample_centered + cut_out_after > end_frame + right:
-            full_template[sample_centered - cut_out_before :] = template_cut[: -cut_out_after - (end_frame + right)]
+        elif sample_centered + cut_out_after > num_samples_local_waveform:
+            full_template[sample_centered - cut_out_before :] = template_cut[
+                : -(cut_out_after + sample_centered - num_samples_local_waveform)
+            ]
         else:
             full_template[sample_centered - cut_out_before : sample_centered + cut_out_after] = template_cut
         X[:, i] = full_template.T.flatten()
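Why the mask representation is faster (the "speed-up" in the branch name): the per-unit channel neighborhoods become rows of one boolean (num_units, num_channels) array, so overlap tests and sparsity unions are vectorized bitwise operations instead of np.intersect1d/np.union1d calls on per-unit index arrays. A self-contained sketch with toy shapes (an illustration, not repo code; in the real code the mask comes from ChannelSparsity.mask):

import numpy as np

num_units, num_channels = 3, 8
sparsity_mask = np.zeros((num_units, num_channels), dtype=bool)
sparsity_mask[0, 0:4] = True   # unit 0 lives on channels 0-3
sparsity_mask[1, 3:6] = True   # unit 1 lives on channels 3-5
sparsity_mask[2, 6:8] = True   # unit 2 lives on channels 6-7

# overlap test, as in _are_unit_indices_overlapping: a single vectorized AND
assert np.any(sparsity_mask[0] & sparsity_mask[1])      # share channel 3
assert not np.any(sparsity_mask[0] & sparsity_mask[2])  # disjoint

# union of sparsities for a collision, as in fit_collision: OR the rows, then recover indices
common_mask = sparsity_mask[0] | sparsity_mask[1]
(sparse_indices,) = np.nonzero(common_mask)
print(sparse_indices)  # [0 1 2 3 4 5]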