Skip to content

Commit

Permalink
Change to None, add some ifs
Browse files Browse the repository at this point in the history
  • Loading branch information
chrishalcrow committed Apr 15, 2024
1 parent 66077e2 commit 11d5d39
Showing 1 changed file with 24 additions and 16 deletions.
40 changes: 24 additions & 16 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,10 +848,10 @@ def clean_refractory_period(times, refractory_period):
def _add_spikes_to_spiketrain(
spike_indices,
spike_labels,
segment_indices=[],
segment_indices=None,
added_spikes_indices=None,
added_spikes_labels=None,
added_segment_indices=[],
added_segment_indices=None,
replace=False,
seed=None,
):
Expand Down Expand Up @@ -886,42 +886,50 @@ def _add_spikes_to_spiketrain(
"""

if added_spikes_indices is None:
if segment_indices is None:
return spike_indices, spike_labels
else:
return spike_indices, spike_labels, segment_indices

# check lengths are consistent
assert len(spike_indices) == len(spike_labels), "Length of spike indices and labels are not equal"
assert (len(segment_indices) == 0) or (
assert (segment_indices is None) or (
len(spike_indices) == len(segment_indices)
), "Length of spike indices and segments are not equal"
assert len(added_spikes_indices) == len(
added_spikes_labels
), "Length of added spike indices and labels are not equal"
assert (len(added_segment_indices) == 0) or (
assert (added_spikes_indices is None) or len(added_spikes_indices) == len(
added_spikes_labels
), "Length of added spike indices and labels are not equal"
assert (added_segment_indices is None) or (
len(added_spikes_indices) == len(added_segment_indices)
), "Length of added spike indices and segments are not equal"
assert (segment_indices is None and added_segment_indices is None) or (
segment_indices is not None and added_segment_indices is not None
), "Existing and added segment indices are inconsistent. Possibly one is non-zero."

new_spike_indices = np.array(spike_indices)
new_spike_labels = np.array(spike_labels)
new_spike_segments = np.array(segment_indices)
if segment_indices is not None:
new_segment_indices = np.array(segment_indices)

rng = np.random.default_rng(seed=seed)

if replace:
replacement_indices = rng.choice(len(spike_indices), len(added_spikes_indices), replace=False)
new_spike_indices[replacement_indices] = added_spikes_indices
new_spike_labels[replacement_indices] = added_spikes_labels
if len(segment_indices) != 0:
print(new_spike_segments[replacement_indices])
print(added_segment_indices)
new_spike_segments[replacement_indices] = added_segment_indices
if added_segment_indices is not None:
new_segment_indices[replacement_indices] = added_segment_indices
else:
new_spike_indices = np.concatenate((new_spike_indices, added_spikes_indices))
new_spike_labels = np.concatenate((new_spike_labels, added_spikes_labels))
if len(segment_indices) != 0:
new_spike_segments = np.concatenate((new_spike_segments, added_segment_indices))
if added_segment_indices is not None:
new_segment_indices = np.concatenate((new_segment_indices, added_segment_indices))

if len(segment_indices) == 0:
if segment_indices is None:
return new_spike_indices, new_spike_labels
else:
return new_spike_indices, new_spike_labels, new_spike_segments
return new_spike_indices, new_spike_labels, new_segment_indices


def inject_some_duplicate_units(sorting, num=4, max_shift=5, ratio=None, seed=None):
Expand Down

0 comments on commit 11d5d39

Please sign in to comment.