Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

patches: fix Spiking.t_samp_ms and IOProc sample skipping #47

Merged
merged 2 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cleo/ephys/spiking.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ def _update_saved_vars(self, t_ms, i, t_samp_ms):
if self.probe.save_history:
self.i = np.concatenate([self.i, i])
self.t_ms = np.concatenate([self.t_ms, t_ms])
self.t_samp_ms = np.concatenate([self.t_samp_ms, [t_samp_ms]])
t_samp_ms_rep = np.full_like(t_ms, t_samp_ms)
self.t_samp_ms = np.concatenate([self.t_samp_ms, t_samp_ms_rep])

def connect_to_neuron_group(
self, neuron_group: NeuronGroup, **kwparams
Expand Down
3 changes: 2 additions & 1 deletion cleo/ioproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,8 @@ def _is_currently_idle(self, query_time_ms):

def is_sampling_now(self, query_time_ms):
if self.sampling == "fixed":
if np.isclose(query_time_ms % self.sample_period_ms, 0):
resid = query_time_ms % self.sample_period_ms
if np.isclose(resid, 0) or np.isclose(resid, self.sample_period_ms):
return True
elif self.sampling == "when idle":
if np.isclose(query_time_ms % self.sample_period_ms, 0):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cleosim"
version = "0.14.1"
version = "0.14.2"
description = "Cleo: the Closed-Loop, Electrophysiology, and Optogenetics experiment simulation testbed"
authors = [
"Kyle Johnsen <[email protected]>",
Expand Down
15 changes: 9 additions & 6 deletions tests/ephys/test_spiking.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""Tests for ephys.spiking module"""
import pytest
import numpy as np
from brian2 import SpikeGeneratorGroup, ms, mm, Network
import neo
import quantities as pq
from brian2 import Network, SpikeGeneratorGroup, mm, ms

from cleo import CLSimulator
from cleo.ephys import *
from cleo.ephys import MultiUnitSpiking, Probe, SortedSpiking
from cleo.ioproc import RecordOnlyProcessor


Expand Down Expand Up @@ -81,7 +79,7 @@ def test_MUS_multiple_contacts():
assert np.sum(mus.i == 0) < len(indices)
assert np.sum(mus.i == 1) < len(indices)

assert len(mus.i) == len(mus.t_ms)
assert len(mus.i) == len(mus.t_ms) == len(mus.t_samp_ms)


def test_MUS_multiple_groups():
Expand All @@ -106,6 +104,8 @@ def test_MUS_multiple_groups():
assert 20 < np.sum(mus.i == 0) < 60
# second channel would have caught all spikes from sgg1 and sgg2
assert np.sum(mus.i == 1) == 60
assert len(mus.t_ms) == len(mus.t_samp_ms)
assert np.all(mus.t_samp_ms == 10)


def test_MUS_reset():
Expand Down Expand Up @@ -156,7 +156,10 @@ def test_SortedSpiking():
assert all(i == [2, 3, 5])

for i in (0, 1, 4):
assert not i in ss.i
assert i not in ss.i

assert ss.t_ms.shape == ss.i.shape == ss.t_samp_ms.shape
assert np.all(np.in1d(ss.t_samp_ms, [3, 4, 5, 6]))


def _test_reset(spike_signal_class):
Expand Down
31 changes: 23 additions & 8 deletions tests/ioproc/test_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,18 @@ def __init__(self, sample_period_ms, **kwargs):
super().__init__(sample_period_ms, **kwargs)
self.delay = 1.199
self.component = MyProcessingBlock(delay=ConstantDelay(self.delay))
self.count = 0

def process(self, state_dict: dict, sample_time_ms: float) -> Tuple[dict, float]:
input = state_dict["in"]
out, out_t = self.component.process(
input, sample_time_ms, measurement_time=sample_time_ms
)
return {"out": out}, out_t
try:
input = state_dict["in"]
out, out_t = self.component.process(
input, sample_time_ms, measurement_time=sample_time_ms
)
return {"out": out}, out_t
except KeyError:
self.count += 1
return {}, sample_time_ms


def _test_LatencyIOProcessor(myLIOP, t, sampling, inputs, outputs):
Expand Down Expand Up @@ -117,14 +122,13 @@ class SampleCounter(cleo.IOProcessor):
def is_sampling_now(self, t_query_ms) -> np.bool:
return t_query_ms % self.sample_period_ms == 0

def __init__(self):
def __init__(self, sample_period_ms=1):
self.count = 0
self.sample_period_ms = 1
self.sample_period_ms = sample_period_ms
self.latest_ctrl_signal = {}

def put_state(self, state_dict: dict, sample_time_ms: float):
self.count += 1
print(sample_time_ms)
return ({}, sample_time_ms)

def get_ctrl_signals(self, query_time_ms: np.float) -> dict:
Expand All @@ -141,6 +145,17 @@ def test_no_skip_sampling():
assert sc.count == nsamp


def test_no_skip_sampling_short():
net = Network()
sim = cleo.CLSimulator(net)
Tsamp = 0.2 * ms
liop = MyLIOP(Tsamp / ms)
sim.set_io_processor(liop)
nsamp = 20
sim.run(nsamp * Tsamp)
assert liop.count == nsamp


class WaveformController(LatencyIOProcessor):
def process(self, state_dict, t_ms):
return {"steady": t_ms, "time-varying": t_ms + 1}, t_ms + 3
Expand Down
Loading