Skip to content

Commit

Permalink
use rounded sample times to avoid skips
Browse files Browse the repository at this point in the history
  • Loading branch information
kjohnsen committed Mar 26, 2024
1 parent 8eff537 commit fc66e39
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
8 changes: 5 additions & 3 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,9 +460,11 @@ def set_io_processor(
return

def communicate_with_io_proc(t):
if io_processor.is_sampling_now(t / ms):
io_processor.put_state(self.get_state(), t / ms)
stim_values = io_processor.get_stim_values(t / ms)
# assuming no one will have timesteps shorter than nanoseconds...
now_ms = round(t / ms, 6)
if io_processor.is_sampling_now(now_ms):
io_processor.put_state(self.get_state(), now_ms)
stim_values = io_processor.get_stim_values(now_ms)
self.update_stimulators(stim_values)

# communication should be at every timestep. The IOProcessor
Expand Down
2 changes: 1 addition & 1 deletion cleo/ioproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def is_sampling_now(self, query_time_ms):
if np.isclose(query_time_ms % self.sample_period_ms, 0):
return True
elif self.sampling == "when idle":
if query_time_ms % self.sample_period_ms == 0:
if np.isclose(query_time_ms % self.sample_period_ms, 0):
if self._is_currently_idle(query_time_ms):
self._needs_off_schedule_sample = False
return True
Expand Down
31 changes: 20 additions & 11 deletions tests/ioproc/test_processing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Tests for cleo/processing/__init__.py"""
from typing import Any, Tuple

from brian2 import Hz, Network, PoissonGroup, ms, np
from brian2 import Hz, Network, NeuronGroup, ms, np

from cleo import CLSimulator
import cleo
from cleo.ioproc import ConstantDelay, LatencyIOProcessor, ProcessingBlock


Expand Down Expand Up @@ -111,25 +111,34 @@ def test_LatencyIOProcessor_wait_parallel():
_test_LatencyIOProcessor(myLIOP, t, sampling, inputs, outputs)


class SampleCounter(LatencyIOProcessor):
class SampleCounter(cleo.IOProcessor):
"""Just count samples"""

def __init__(self, sample_period_ms, **kwargs):
super().__init__(sample_period_ms, **kwargs)
def is_sampling_now(self, t_query_ms) -> np.bool:
return t_query_ms % self.sample_period_ms == 0

def __init__(self):
self.count = 0
self.sample_period_ms = 1
self.latest_ctrl_signal = {}

def process(self, state_dict: dict, sample_time_ms: float) -> Tuple[dict, float]:
def put_state(self, state_dict: dict, sample_time_ms: float):
self.count += 1
print(sample_time_ms)
return ({}, sample_time_ms)

def get_ctrl_signals(self, query_time_ms: np.float) -> dict:
return {}


def test_no_skip_sampling():
sc = SampleCounter(1)
net = Network(PoissonGroup(1, 100 * Hz))
sim = CLSimulator(net)
sc = SampleCounter()
net = Network()
sim = cleo.CLSimulator(net)
sim.set_io_processor(sc)
sim.run(150 * ms)
assert sc.count == 150
nsamp = 3000
sim.run(nsamp * ms)
assert sc.count == nsamp


class WaveformController(LatencyIOProcessor):
Expand Down

0 comments on commit fc66e39

Please sign in to comment.