Skip to content

Commit

Permalink
Merge pull request #1948 from samuelgarcia/generator
Browse files Browse the repository at this point in the history
Refactor generate.py
  • Loading branch information
alejoe91 authored Sep 1, 2023
2 parents 3817ee0 + 20f5108 commit 23aef27
Show file tree
Hide file tree
Showing 14 changed files with 1,466 additions and 1,124 deletions.
25 changes: 9 additions & 16 deletions src/spikeinterface/comparison/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
BaseSorting,
WaveformExtractor,
NumpySorting,
NpzSortingExtractor,
InjectTemplatesRecording,
)
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.core import generate_sorting
from spikeinterface.core.generate import generate_sorting, InjectTemplatesRecording, _ensure_seed


class HybridUnitsRecording(InjectTemplatesRecording):
Expand Down Expand Up @@ -60,6 +58,7 @@ def __init__(
amplitude_std: float = 0.0,
refractory_period_ms: float = 2.0,
injected_sorting_folder: Union[str, Path, None] = None,
seed=None,
):
num_samples = [
parent_recording.get_num_frames(seg_index) for seg_index in range(parent_recording.get_num_segments())
Expand All @@ -80,8 +79,8 @@ def __init__(
num_units=len(templates),
sampling_frequency=fs,
durations=durations,
firing_rate=firing_rate,
refractory_period=refractory_period_ms,
firing_rates=firing_rate,
refractory_period_ms=refractory_period_ms,
)
# save injected sorting if necessary
self.injected_sorting = injected_sorting
Expand All @@ -90,17 +89,10 @@ def __init__(
self.injected_sorting = self.injected_sorting.save(folder=injected_sorting_folder)

if amplitude_factor is None:
amplitude_factor = [
[
np.random.normal(
loc=1.0,
scale=amplitude_std,
size=len(self.injected_sorting.get_unit_spike_train(unit_id, segment_index=seg_index)),
)
for unit_id in self.injected_sorting.unit_ids
]
for seg_index in range(parent_recording.get_num_segments())
]
seed = _ensure_seed(seed)
rng = np.random.default_rng(seed=seed)
num_spikes = self.injected_sorting.to_spike_vector().size
amplitude_factor = rng.normal(loc=1.0, scale=amplitude_std, size=num_spikes)

InjectTemplatesRecording.__init__(
self, self.injected_sorting, templates, nbefore, amplitude_factor, parent_recording, num_samples
Expand All @@ -116,6 +108,7 @@ def __init__(
amplitude_std=amplitude_std,
refractory_period_ms=refractory_period_ms,
injected_sorting_folder=None,
seed=seed,
)


Expand Down
9 changes: 8 additions & 1 deletion src/spikeinterface/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@
inject_some_duplicate_units,
inject_some_split_units,
synthetize_spike_train_bad_isi,
generate_templates,
NoiseGeneratorRecording,
noise_generator_recording,
generate_recording_by_size,
InjectTemplatesRecording,
inject_templates,
generate_ground_truth_recording,
)

# utils to append and concatenate segment (equivalent to OLD MultiRecordingTimeExtractor)
Expand Down Expand Up @@ -109,7 +116,7 @@
)

# templates addition
from .injecttemplates import InjectTemplatesRecording, InjectTemplatesRecordingSegment, inject_templates
# from .injecttemplates import InjectTemplatesRecording, InjectTemplatesRecordingSegment, inject_templates

# template tools
from .template_tools import (
Expand Down
Loading

0 comments on commit 23aef27

Please sign in to comment.