Skip to content

Commit

Permalink
v0.16 (#54)
Browse files Browse the repository at this point in the history
* Update t_ms and interval_ms

* updates t in other files

* updated t

* changed more _ms

* added import

* addes imports

* fixed base.py and made changes to lfp  - still failing 10 cases

* Created jupyteer notebook for calculating ratios

* added ratio into opsin library and added 2p action spectra into spectrum variables for all opsins

* progress on all_optical fig

* improve action spectrum extrapolation

* start Newman 15 with alternate opsins

* update Overview doc

* finish newman15 validation w/ alternate opsins

* Update paper plot style

* style_plots_for_paper docstring

* change Jupyter to Python for linguist

* change OGB-1 name

* add warning to scope when sensor not injected

* add showcase to docs

* implement LIOP reset in _base_reset

* add showcase doc to index

* add _base_reset to sim.reset

* fresh run through of all-optical-fig notebook

* add README to notebooks, env info to all_optical_fig.ipynb

* bump to v0.15.0

* remove images [skip ci]

* untrack images

* more organizing notebook images

* move Sridharan fig2

* use Brian units everywhere

* reorganize ioproc tests

* add firing_rate_estimate and pi_ctrl methods, revise PI_Ctrl notebook

* ax ioproc folder

* ax ProcessingBlock

* upgrade nptyping, make sim.remove() private

* fix more unit problems

* get overview working

* get tutorials working

* use explicit units on Light

* add units to SimpleOpsin gain

* extend font stack

* fix spike count problem

* scope to_neo

* 2p light to_neo

* fix spiking test fail

* Use utilities.rng

* units on Light.values and other small tweaks

* add version note to newman15 notebooks

* add version note to opto val notebook

* fix electrodes tutorial

* fix opto tutorial

* fix multi_opto notebook, add range to plot_spectra

* fix on_off_ctrl notebook

* update PI ctrl notebook

* fix all_optical tutorial

* fix video viz tutorial

* fix Neo tutorial

* fix advanced LFP notebook

* rename firing_rate_estimate to exp_firing_rate_estimate

* fix global set_seed

* progress on lqr tutorial, before cutting adaptive control

* revamp ldsCtrlEst tutorial

* support numpy >= 2.0

* remove bug workaround (closes Remove bug workaround in tutorials #16)

* fix light.to_neo

* replace nptyping with jaxtyping

* remove imported members from ioproc docs

* mark lfp tests as slow

* polish tutorial notebooks

* tweak docstrings

* add LatencyIOProcessor to base cleo import

* clean up imports

* fix lqr notebook, fresh rerun of all

* polish video tutorial, exit if no ffmpeg

* skip execution of video write cell

* tweak overview docs, increase test cell timeout

* simplify 4-state opsin model

* eliminate warnings

* rerun opto tutorial, increase timestep

---------

Co-authored-by: Arnav Tripathi <[email protected]>
  • Loading branch information
kjohnsen and arnavt2955 authored Sep 18, 2024
1 parent ebf9f26 commit 76c0736
Show file tree
Hide file tree
Showing 76 changed files with 8,440 additions and 218,361 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tmp/*
**/tmp/*
# figures from notebooks not saved by default
notebooks/img/*
!notebooks/img/orig/
notebooks/results/*
docs/tutorials/*.svg

Expand Down
8 changes: 4 additions & 4 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@ The easiest way is to enable Ruff as the formatter in your IDE with auto-formatt
I was going to lint using flake8 but then I realized, this is a small research code package! We don't need super pretty, consistent code. Just try to follow Python conventions and use Black.

## Structure
Originally, the intention was for opto and electrodes to live under stimulators and recorders, respectively. This made `opto_stim = cleo.opto.OptogeneticIntervention(...)` possible but not for importing from that second-level shortcut (`from cleo.opto import ...`). Thus, they were moved up a level.

We still have some import shortcuts for users, making everything in the `ephys` subpackage (the contents of lfp, spiking, and probes modules) available under `cleo.ephys`. We do this by importing the submodules' contents in `__init__.py` files. We can then test the shortcut imports by making sure to use them in the unit tests. However, we must use the full import path in the source code itself to avoid circular import errors.
We structure big modules like `ephys` as a folder with an `__init__.py` that imports from the submodules (`spiking`, `lfp`, etc.).
This allows us to structure the code nicely but still allow for short imports.
We can then test the shortcut imports by making sure to use them in the unit tests. However, we must use the full import path in the source code itself to avoid circular import errors.

## Notebooks
Please use [nbdev for Git-friendly Jupyter](https://nbdev.fast.ai/tutorials/git_friendly_jupyter.html).
Please use [nbdev for Git-friendly Jupyter](https://nbdev.fast.ai/tutorials/git_friendly_jupyter.html), especially `nbdev_clean` before committing Jupyter notebooks.
1 change: 1 addition & 0 deletions cleo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
Stimulator,
SynapseDevice,
)
from cleo.ioproc import LatencyIOProcessor
97 changes: 61 additions & 36 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from matplotlib.artist import Artist
from mpl_toolkits.mplot3d import Axes3D

import cleo
from cleo.registry import registry_for_sim
from cleo.utilities import add_to_neo_segment, analog_signal, brian_safe_name
from cleo.utilities import add_to_neo_segment, brian_safe_name, rng, unit_safe_append


class NeoExportable(ABC):
Expand Down Expand Up @@ -158,20 +159,20 @@ class IOProcessor(ABC):
class more useful, since delay handling is already defined.
"""

sample_period_ms: float = 1
sample_period: Quantity = 1 * ms
"""Determines how frequently the processor takes samples"""

latest_ctrl_signal: dict = field(factory=dict, init=False, repr=False)
"""The most recent control signal returned by :meth:`get_ctrl_signals`"""

@abstractmethod
def is_sampling_now(self, time) -> bool:
def is_sampling_now(self, t_now: Quantity) -> bool:
"""Determines whether the processor will take a sample at this timestep.
Parameters
----------
time : Brian 2 temporal Unit
Current timestep.
t_now : Quantity
Current time.
Returns
-------
Expand All @@ -180,27 +181,27 @@ def is_sampling_now(self, time) -> bool:
pass

@abstractmethod
def put_state(self, state_dict: dict, sample_time_ms: float) -> None:
def put_state(self, state_dict: dict, t_samp: Quantity) -> None:
"""Deliver network state to the :class:`IOProcessor`.
Parameters
----------
state_dict : dict
A dictionary of recorder measurements, as returned by
:func:`~cleo.CLSimulator.get_state()`
sample_time_ms: float
t_samp: Quantity
The current simulation timestep. Essential for simulating
control latency and for time-varying control.
"""
pass

@abstractmethod
def get_ctrl_signals(self, query_time_ms: float) -> dict:
def get_ctrl_signals(self, t_query: Quantity) -> dict:
"""Get per-stimulator control signal from the :class:`~cleo.IOProcessor`.
Parameters
----------
query_time_ms : float
t_query : Quantity
Current simulation time.
Returns
Expand All @@ -210,16 +211,16 @@ def get_ctrl_signals(self, query_time_ms: float) -> dict:
"""
pass

def get_stim_values(self, query_time_ms: float) -> dict:
ctrl_signals = self.get_ctrl_signals(query_time_ms)
def get_stim_values(self, t_query: Quantity) -> dict:
ctrl_signals = self.get_ctrl_signals(t_query)
self.latest_ctrl_signal.update(ctrl_signals)
stim_value_conversions = self.preprocess_ctrl_signals(
self.latest_ctrl_signal, query_time_ms
self.latest_ctrl_signal, t_query
)
return ctrl_signals | stim_value_conversions

def preprocess_ctrl_signals(
self, latest_ctrl_signals: dict, query_time_ms: float
self, latest_ctrl_signals: dict, t_query: Quantity
) -> dict:
"""Preprocess control signals as needed to control stimulator waveforms between samples.
Expand All @@ -235,7 +236,7 @@ def preprocess_ctrl_signals(
Parameters
----------
query_time_ms : float
t_query : float
Current simulation time.
Returns
Expand Down Expand Up @@ -270,28 +271,21 @@ def get_state(self) -> Any:
class Stimulator(InterfaceDevice, NeoExportable):
"""Device for manipulating the network"""

value: Any = field(init=False, default=None)
value: Any = field(init=False, default=0)
"""The current value of the stimulator device"""
default_value: Any = 0
"""The default value of the device---used on initialization and on :meth:`~reset`"""
t_ms: list[float] = field(factory=list, init=False, repr=False)
t: Quantity = field(factory=lambda: np.array([]) * ms, init=False, repr=False)
"""Times stimulator was updated, stored if :attr:`~cleo.InterfaceDevice.save_history`"""
values: list[Any] = field(factory=list, init=False, repr=False)
"""Values taken by the stimulator at each :meth:`~update` call,
stored if :attr:`~cleo.InterfaceDevice.save_history`"""

def __attrs_post_init__(self):
self.value = self.default_value
self._init_saved_vars()
self.reset()

def _init_saved_vars(self):
if self.save_history:
if self.sim:
t0 = self.sim.network.t / ms
else:
t0 = 0
self.t_ms = [t0]
self.values = [self.value]
self.t = [] * ms
self.values = []

def update(self, ctrl_signal) -> None:
"""Set the stimulator value.
Expand All @@ -308,16 +302,20 @@ def update(self, ctrl_signal) -> None:
"""
self.value = ctrl_signal
if self.save_history:
self.t_ms.append(self.sim.network.t / ms)
if self.sim:
t = self.sim.network.t
else:
t = 0 * ms
self.t = unit_safe_append(self.t, t)
self.values.append(self.value)

def reset(self, **kwargs) -> None:
"""Reset the stimulator device to a neutral state"""
self.value = self.default_value
self._init_saved_vars()
self.update(0)

def to_neo(self):
signal = analog_signal(self.t_ms, self.values, "dimensionless")
signal = cleo.utilities.analog_signal(self.t, self.values, "dimensionless")
signal.name = self.name
signal.description = "Exported from Cleo stimulator device"
signal.annotate(export_datetime=datetime.datetime.now())
Expand Down Expand Up @@ -407,6 +405,32 @@ def inject(
self.devices.add(device)
return self

def _remove(self, device: InterfaceDevice) -> CLSimulator:
"""Remove device and associated Brian objects from the simulation (UNTESTED).
Parameters
----------
device : InterfaceDevice
Device to remove
Returns
-------
CLSimulator
self
"""
for brian_object in device.brian_objects:
if brian_object in self.network.objects:
self.network.remove(brian_object)
if isinstance(device, Recorder):
if device.name in self.recorders:
del self.recorders[device.name]
if isinstance(device, Stimulator):
if device.name in self.stimulators:
del self.stimulators[device.name]
self.devices.remove(device)
self.network.store(self._net_store_name)
return self

def get_state(self) -> dict:
"""Return current recorder measurements.
Expand Down Expand Up @@ -462,10 +486,10 @@ def set_io_processor(

def communicate_with_io_proc(t):
# assuming no one will have timesteps shorter than nanoseconds...
now_ms = round(t / ms, 6)
if io_processor.is_sampling_now(now_ms):
io_processor.put_state(self.get_state(), now_ms)
stim_values = io_processor.get_stim_values(now_ms)
t_now = round(t / ms, 6) * ms
if io_processor.is_sampling_now(t_now):
io_processor.put_state(self.get_state(), t_now)
stim_values = io_processor.get_stim_values(t_now)
self.update_stimulators(stim_values)

# communication should be at every timestep. The IOProcessor
Expand Down Expand Up @@ -643,7 +667,7 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None
if "i_targets" in kwparams:
raise ValueError("p_expression and i_targets are incompatible")
p_expression = kwparams.get("p_expression", 1)
expr_bool = np.random.rand(neuron_group.N) < p_expression
expr_bool = rng.random(neuron_group.N) < p_expression
i_targets = np.where(expr_bool)[0]
elif "i_targets" in kwparams:
i_targets = kwparams["i_targets"]
Expand All @@ -670,7 +694,8 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None

# store at the end, after all checks have passed
self.source_ngs[neuron_group.name] = source_ng
self.brian_objects.add(source_ng)
if source_ng is not neuron_group:
self.brian_objects.add(source_ng)
self.synapses[neuron_group.name] = syn
self.brian_objects.add(syn)

Expand Down Expand Up @@ -707,7 +732,7 @@ def modify_model_and_params_for_ng(
A tuple containing an Equations object
and a parameter dictionary, constructed from :attr:`~model`
and :attr:`~params`, respectively, with modified names for use
in :attr:`~cleo.opto.OptogeneticIntervention.synapses`
in :attr:`synapses`
"""
model = self.model

Expand Down
20 changes: 11 additions & 9 deletions cleo/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from cleo.utilities import (
modify_model_with_eqs,
rng,
uniform_cylinder_rθz,
xyz_from_rθz,
)
Expand Down Expand Up @@ -83,9 +84,9 @@ def assign_coords_rand_rect_prism(
unit : Unit, optional
Brian unit to specify scale implied in limits, by default mm
"""
x = (xlim[1] - xlim[0]) * np.random.random(len(neuron_group)) + xlim[0]
y = (ylim[1] - ylim[0]) * np.random.random(len(neuron_group)) + ylim[0]
z = (zlim[1] - zlim[0]) * np.random.random(len(neuron_group)) + zlim[0]
x = (xlim[1] - xlim[0]) * rng.random(len(neuron_group)) + xlim[0]
y = (ylim[1] - ylim[0]) * rng.random(len(neuron_group)) + ylim[0]
z = (zlim[1] - zlim[0]) * rng.random(len(neuron_group)) + zlim[0]
assign_xyz(neuron_group, x, y, z, unit)


Expand Down Expand Up @@ -114,10 +115,10 @@ def assign_coords_rand_cylinder(
xyz_start = np.array(xyz_start)
xyz_end = np.array(xyz_end)
# sample uniformly over r**2 for equal area
rs = np.sqrt(radius**2 * np.random.random(len(neuron_group)))
thetas = 2 * np.pi * np.random.random(len(neuron_group))
rs = np.sqrt(radius**2 * rng.random(len(neuron_group)))
thetas = 2 * np.pi * rng.random(len(neuron_group))
cyl_length = np.linalg.norm(xyz_end - xyz_start)
z_cyls = cyl_length * np.random.random(len(neuron_group))
z_cyls = cyl_length * rng.random(len(neuron_group))

xs, ys, zs = xyz_from_rθz(rs, thetas, z_cyls, xyz_start, xyz_end)

Expand Down Expand Up @@ -196,9 +197,10 @@ def coords_from_xyz(x: Quantity, y: Quantity, z: Quantity) -> Quantity:
return (
np.concatenate(
[
np.reshape(x / meter, (*x.shape, 1)),
np.reshape(y / meter, (*y.shape, 1)),
np.reshape(z / meter, (*z.shape, 1)),
# use [:] to work around VariableView.shape getting parent group shape
np.reshape(x / meter, (*x[:].shape, 1)),
np.reshape(y / meter, (*y[:].shape, 1)),
np.reshape(z / meter, (*z[:].shape, 1)),
],
axis=-1,
)
Expand Down
Loading

0 comments on commit 76c0736

Please sign in to comment.