Skip to content

Commit

Permalink
Merge pull request #1970 from h-mayorquin/change_default_in_generate_…
Browse files Browse the repository at this point in the history
…recording

Some additions to generate.py after #1948
  • Loading branch information
alejoe91 authored Sep 13, 2023
2 parents 5ee4563 + ac65bc5 commit 35a9a1c
Showing 1 changed file with 31 additions and 26 deletions.
57 changes: 31 additions & 26 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
import numpy as np
from typing import Union, Optional, List, Literal
import warnings


from .numpyextractors import NumpyRecording, NumpySorting
Expand Down Expand Up @@ -31,7 +32,7 @@ def generate_recording(
set_probe: Optional[bool] = True,
ndim: Optional[int] = 2,
seed: Optional[int] = None,
mode: Literal["lazy", "legacy"] = "legacy",
mode: Literal["lazy", "legacy"] = "lazy",
) -> BaseRecording:
"""
Generate a recording object.
Expand All @@ -51,10 +52,10 @@ def generate_recording(
The number of dimensions of the probe, by default 2. Set to 3 to make 3 dimensional probes.
seed : Optional[int]
A seed for the np.ramdom.default_rng function
mode: str ["lazy", "legacy"] Default "legacy".
mode: str ["lazy", "legacy"] Default "lazy".
"legacy": generate a NumpyRecording with white noise.
This mode is kept for backward compatibility and will be deprecated in next release.
"lazy": return a NoiseGeneratorRecording
This mode is kept for backward compatibility and will be deprecated version 0.100.0.
"lazy": return a NoiseGeneratorRecording instance.
Returns
-------
Expand All @@ -64,6 +65,10 @@ def generate_recording(
seed = _ensure_seed(seed)

if mode == "legacy":
warnings.warn(
"generate_recording() : mode='legacy' will be deprecated in version 0.100.0. Use mode='lazy' instead.",
DeprecationWarning,
)
recording = _generate_recording_legacy(num_channels, sampling_frequency, durations, seed)
elif mode == "lazy":
recording = NoiseGeneratorRecording(
Expand Down Expand Up @@ -538,7 +543,7 @@ def synthetize_spike_train_bad_isi(duration, baseline_rate, num_violations, viol

class NoiseGeneratorRecording(BaseRecording):
"""
A lazy recording that generates random samples if and only if `get_traces` is called.
A lazy recording that generates white noise samples if and only if `get_traces` is called.
This done by tiling small noise chunk.
Expand All @@ -555,7 +560,7 @@ class NoiseGeneratorRecording(BaseRecording):
The sampling frequency of the recorder.
durations : List[float]
The durations of each segment in seconds. Note that the length of this list is the number of segments.
noise_level: float, default 5:
noise_level: float, default 1:
Std of the white noise
dtype : Optional[Union[np.dtype, str]], default='float32'
The dtype of the recording. Note that only np.float32 and np.float64 are supported.
Expand All @@ -581,7 +586,7 @@ def __init__(
num_channels: int,
sampling_frequency: float,
durations: List[float],
noise_level: float = 5.0,
noise_level: float = 1.0,
dtype: Optional[Union[np.dtype, str]] = "float32",
seed: Optional[int] = None,
strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated",
Expand Down Expand Up @@ -647,7 +652,7 @@ def __init__(
if self.strategy == "tile_pregenerated":
rng = np.random.default_rng(seed=self.seed)
self.noise_block = (
rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype) * noise_level
rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype) * noise_level
)
elif self.strategy == "on_the_fly":
pass
Expand All @@ -664,35 +669,35 @@ def get_traces(
start_frame = 0 if start_frame is None else max(start_frame, 0)
end_frame = self.num_samples if end_frame is None else min(end_frame, self.num_samples)

start_frame_mod = start_frame % self.noise_block_size
end_frame_mod = end_frame % self.noise_block_size
start_frame_within_block = start_frame % self.noise_block_size
end_frame_within_block = end_frame % self.noise_block_size
num_samples = end_frame - start_frame

traces = np.empty(shape=(num_samples, self.num_channels), dtype=self.dtype)

start_block_index = start_frame // self.noise_block_size
end_block_index = end_frame // self.noise_block_size
first_block_index = start_frame // self.noise_block_size
last_block_index = end_frame // self.noise_block_size

pos = 0
for block_index in range(start_block_index, end_block_index + 1):
for block_index in range(first_block_index, last_block_index + 1):
if self.strategy == "tile_pregenerated":
noise_block = self.noise_block
elif self.strategy == "on_the_fly":
rng = np.random.default_rng(seed=(self.seed, block_index))
noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels)).astype(self.dtype)
noise_block = rng.standard_normal(size=(self.noise_block_size, self.num_channels), dtype=self.dtype)
noise_block *= self.noise_level

if block_index == start_block_index:
if start_block_index != end_block_index:
end_first_block = self.noise_block_size - start_frame_mod
traces[:end_first_block] = noise_block[start_frame_mod:]
if block_index == first_block_index:
if first_block_index != last_block_index:
end_first_block = self.noise_block_size - start_frame_within_block
traces[:end_first_block] = noise_block[start_frame_within_block:]
pos += end_first_block
else:
# special case when unique block
traces[:] = noise_block[start_frame_mod : start_frame_mod + traces.shape[0]]
elif block_index == end_block_index:
if end_frame_mod > 0:
traces[pos:] = noise_block[:end_frame_mod]
traces[:] = noise_block[start_frame_within_block : start_frame_within_block + num_samples]
elif block_index == last_block_index:
if end_frame_within_block > 0:
traces[pos:] = noise_block[:end_frame_within_block]
else:
traces[pos : pos + self.noise_block_size] = noise_block
pos += self.noise_block_size
Expand All @@ -710,7 +715,7 @@ def get_traces(

def generate_recording_by_size(
full_traces_size_GiB: float,
num_channels: int = 1024,
num_channels: int = 384,
seed: Optional[int] = None,
strategy: Literal["tile_pregenerated", "on_the_fly"] = "tile_pregenerated",
) -> NoiseGeneratorRecording:
Expand All @@ -719,15 +724,15 @@ def generate_recording_by_size(
This is a convenience wrapper around the NoiseGeneratorRecording class where only
the size in GiB (NOT GB!) is specified.
It is generated with 1024 channels and a sampling frequency of 1 Hz. The duration is manipulted to
It is generated with 384 channels and a sampling frequency of 1 Hz. The duration is manipulted to
produced the desired size.
Seee GeneratorRecording for more details.
Parameters
----------
full_traces_size_GiB : float
The size in gibibyte (GiB) of the recording.
The size in gigabytes (GiB) of the recording.
num_channels: int
Number of channels.
seed : int, optional
Expand All @@ -740,7 +745,7 @@ def generate_recording_by_size(

dtype = np.dtype("float32")
sampling_frequency = 30_000.0 # Hz
num_channels = 1024
num_channels = 384

GiB_to_bytes = 1024**3
full_traces_size_bytes = int(full_traces_size_GiB * GiB_to_bytes)
Expand Down

0 comments on commit 35a9a1c

Please sign in to comment.