Skip to content

Commit

Permalink
multi-Light-opsin refactor
Browse files Browse the repository at this point in the history
* update deps

* refactor opsin with attrs

* begin attrs/light-opsin split

(tests not working yet for 4-state; need to go rest of the way to light-aggregator neurons)

* split opto into light and opsins

* messy progress

* rename validation to notebooks

* use nbdev git-friendly jupyter

* add multi-wavelength model notebook

* tweak multi-wavelength model NB math

* finish opsin tests

* pivoting to single light neuron group
almost working, but I forgot you can't have multiple Synapses objects
sum to the same variable in a single NeuronGroup.

* get opto refactor working (other stuff is broken)

* fix broken tests after device/opto refactor

* refactor signals too, pass tests

* rename tests to match source

* pass all tests

* progress on multi light-opto tutorial

* finish multi-opto tutorial

* try to fix CI config

* fix annotations for Python <3.9
  • Loading branch information
kjohnsen authored Jul 7, 2023
1 parent d60192a commit 448792c
Show file tree
Hide file tree
Showing 85 changed files with 200,597 additions and 98,591 deletions.
1 change: 1 addition & 0 deletions .gitattributes
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.ipynb merge=nbdev-merge
11 changes: 11 additions & 0 deletions .gitconfig
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Generated by nbdev_install_hooks
#
# If you need to disable this instrumentation do:
# git config --local --unset include.path
#
# To restore:
# git config --local include.path ../.gitconfig
#
[merge "nbdev-merge"]
name = resolve conflicts with nbdev_fix
driver = nbdev_merge %O %A %B %P
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
strategy:
matrix:
os: [macos-latest, windows-latest, ubuntu-latest]
python: [3.7, 3.9]
python: [3.8, 3.9, "3.10", 3.11]

# Steps represent a sequence of tasks that will be executed as part of the job
steps:
Expand Down
5 changes: 4 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -59,4 +59,7 @@ I was going to lint using flake8 but then I realized, this is a small research c
## Structure
Originally, the intention was for opto and electrodes to live under stimulators and recorders, respectively. This made `opto_stim = cleo.opto.OptogeneticIntervention(...)` possible but not for importing from that second-level shortcut (`from cleo.opto import ...`). Thus, they were moved up a level.

We still have some import shortcuts for users, making everything in the `ephys` subpackage (the contents of lfp, spiking, and probes modules) available under `cleo.ephys`. We do this by importing the submodules' contents in `__init__.py` files. We can then test the shortcut imports by making sure to use them in the unit tests. However, we must use the full import path in the source code itself to avoid circular import errors.
We still have some import shortcuts for users, making everything in the `ephys` subpackage (the contents of lfp, spiking, and probes modules) available under `cleo.ephys`. We do this by importing the submodules' contents in `__init__.py` files. We can then test the shortcut imports by making sure to use them in the unit tests. However, we must use the full import path in the source code itself to avoid circular import errors.

## Notebooks
Please use [nbdev for Git-friendly Jupyter](https://nbdev.fast.ai/tutorials/git_friendly_jupyter.html).
2 changes: 2 additions & 0 deletions cleo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

# auto-import submodules
import cleo.ephys
import cleo.opto
Expand Down
188 changes: 66 additions & 122 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
from typing import Any, Tuple, Iterable

from attrs import define, field
from brian2 import (
NeuronGroup,
Subgroup,
Expand All @@ -19,33 +20,26 @@
from matplotlib.artist import Artist


@define(eq=False)
class InterfaceDevice(ABC):
"""Base class for devices to be injected into the network"""

name: str
"""Unique identifier for device.
Used as a key in output/input dicts
"""
brian_objects: set
brian_objects: set = field(factory=set, init=False)
"""All the Brian objects added to the network by this device.
Must be kept up-to-date in :meth:`connect_to_neuron_group` and
other functions so that those objects can be automatically added
to the network when the device is injected.
"""
sim: CLSimulator
"""The simulator the device is injected into
"""
sim: CLSimulator = field(init=False, default=None)
"""The simulator the device is injected into """

def __init__(self, name: str) -> None:
"""
Parameters
----------
name : str
Unique identifier for the device.
"""
self.name = name
self.brian_objects = set()
self.sim = None
name: str = field(kw_only=True)
"""Unique identifier for device, used in sampling, plotting, etc.
Name of the class by default."""

@name.default
def _default_name(self) -> str:
return self.__class__.__name__

def init_for_simulator(self, simulator: CLSimulator) -> None:
"""Initialize device for simulator on initial injection
Expand All @@ -62,6 +56,10 @@ def init_for_simulator(self, simulator: CLSimulator) -> None:
"""
pass

def reset(self, **kwargs) -> None:
"""Reset the device to a neutral state"""
pass

@abstractmethod
def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None:
"""Connect device to given `neuron_group`.
Expand All @@ -73,8 +71,8 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None
Parameters
----------
neuron_group : NeuronGroup
**kwparams : optional, passed from `inject_recorder` or
`inject_stimulator`
**kwparams : optional, passed from `inject` or
`inject`
"""
pass

Expand All @@ -101,7 +99,7 @@ def add_self_to_plot(
"""
return []

def update_artists(artists: list[Artist], *args, **kwargs) -> list[Artist]:
def update_artists(self, artists: list[Artist], *args, **kwargs) -> list[Artist]:
"""Update the artists used to render the device
Used to set the artists' state at every frame of a video visualization.
Expand Down Expand Up @@ -183,6 +181,7 @@ def reset(self, **kwargs) -> None:
pass


@define(eq=False)
class Recorder(InterfaceDevice):
"""Device for taking measurements of the network."""

Expand All @@ -191,48 +190,30 @@ def get_state(self) -> Any:
"""Return current measurement."""
pass

def reset(self, **kwargs) -> None:
"""Reset the recording device to a neutral state"""
pass


@define(eq=False)
class Stimulator(InterfaceDevice):
"""Device for manipulating the network"""

value: Any
value: Any = field(init=False, default=None)
"""The current value of the stimulator device"""
default_value: Any
default_value: Any = 0
"""The default value of the device---used on initialization and on :meth:`~reset`"""
t_ms: list[float]
t_ms: list[float] = field(factory=list, init=False, repr=False)
"""Times stimulator was updated, stored if :attr:`save_history`"""
values: list[Any]
values: list[Any] = field(factory=list, init=False, repr=False)
"""Values taken by the stimulator at each :meth:`~update` call,
stored if :attr:`save_history`"""
save_history: bool
save_history: bool = True
"""Determines whether :attr:`t_ms` and :attr:`values` are recorded"""

def __init__(self, name: str, default_value, save_history: bool = False) -> None:
"""
Parameters
----------
name : str
Unique device name used in :meth:`CLSimulator.update_stimulators`
default_value : any
The stimulator's default value
"""
super().__init__(name)
self.value = default_value
self.default_value = default_value
self.save_history = save_history

def init_for_simulator(self, simulator: CLSimulator) -> None:
super().init_for_simulator(simulator)
self._init_saved_vars()
def __attrs_post_init__(self):
self.value = self.default_value

def _init_saved_vars(self):
if self.save_history:
self.t_ms = [self.sim.network.t / ms]
self.values = [self.default_value]
self.t_ms = []
self.values = []

def update(self, ctrl_signal) -> None:
"""Set the stimulator value.
Expand All @@ -254,47 +235,41 @@ def update(self, ctrl_signal) -> None:

def reset(self, **kwargs) -> None:
"""Reset the stimulator device to a neutral state"""
self.value = self.default_value
self._init_saved_vars()


@define(eq=False)
class CLSimulator:
"""The centerpiece of cleo. Integrates simulation components and runs."""

io_processor: IOProcessor
network: Network
recorders = "set[Recorder]"
stimulators = "set[Stimulator]"
_processing_net_op: NetworkOperation
_net_store_name: str = "cleo default"

def __init__(self, brian_network: Network) -> None:
"""
Parameters
----------
brian_network : Network
The Brian network forming the core model
"""
self.network = brian_network
self.stimulators = {}
self.recorders = {}
self.io_processor = None
self._processing_net_op = None
network: Network = field(repr=False)
"""The Brian network forming the core model"""
io_processor: IOProcessor = field(default=None, init=False)
recorders: dict[str, Recorder] = field(factory=dict, init=False, repr=False)
stimulators: dict[str, Stimulator] = field(factory=dict, init=False, repr=False)
devices: set[InterfaceDevice] = field(factory=set, init=False)
_processing_net_op: NetworkOperation = field(default=None, init=False, repr=False)
_net_store_name: str = field(default="cleo default", init=False, repr=False)

def inject_device(
def inject(
self, device: InterfaceDevice, *neuron_groups: NeuronGroup, **kwparams: Any
) -> None:
) -> CLSimulator:
"""Inject InterfaceDevice into the network, connecting to specified neurons.
Calls :meth:`~InterfaceDevice.connect_to_neuron_group` for each group with
kwparams and adds the device's :attr:`~InterfaceDevice.brian_objects`
to the simulator's :attr:`network`.
Automatically called by :meth:`inject_recorder` and :meth:`inject_stimulator`.
Parameters
----------
device : InterfaceDevice
Device to inject
Returns
-------
CLSimulator
self
"""
if len(neuron_groups) == 0:
raise Exception("Injecting stimulator for no neuron groups is meaningless.")
Expand Down Expand Up @@ -323,50 +298,15 @@ def inject_device(
device.init_for_simulator(self)
device.connect_to_neuron_group(ng, **kwparams)
for brian_object in device.brian_objects:
self.network.add(brian_object)
if brian_object not in self.network.objects:
self.network.add(brian_object)
self.network.store(self._net_store_name)

def inject_stimulator(
self, stimulator: Stimulator, *neuron_groups: NeuronGroup, **kwparams
) -> None:
"""Inject stimulator into given neuron groups.
:meth:`Stimulator.connect_to_neuron_group` is called for each `group`.
Parameters
----------
stimulator : Stimulator
The stimulator to inject
*neuron_groups : NeuronGroup
The groups to inject the stimulator into
**kwparams : any
Passed on to :meth:`Stimulator.connect_to_neuron_group` function.
Necessary for parameters that can vary by neuron group, such
as opsin expression levels.
"""
self.inject_device(stimulator, *neuron_groups, **kwparams)
self.stimulators[stimulator.name] = stimulator

def inject_recorder(
self, recorder: Recorder, *neuron_groups: NeuronGroup, **kwparams
) -> None:
"""Inject recorder into given neuron groups.
:meth:`Recorder.connect_to_neuron_group` is called for each `group`.
Parameters
----------
recorder : Recorder
The recorder to inject into the simulation
*neuron_groups : NeuronGroup
The groups to inject the recorder into
**kwparams : any
Passed on to :meth:`Recorder.connect_to_neuron_group` function.
Necessary for parameters that can vary by neuron group, such
as inhibitory/excitatory cell type
"""
self.inject_device(recorder, *neuron_groups, **kwparams)
self.recorders[recorder.name] = recorder
if isinstance(device, Recorder):
self.recorders[device.name] = device
if isinstance(device, Stimulator):
self.stimulators[device.name] = device
self.devices.add(device)
return self

def get_state(self) -> dict:
"""Return current recorder measurements.
Expand Down Expand Up @@ -396,7 +336,7 @@ def update_stimulators(self, ctrl_signals) -> None:
for name, signal in ctrl_signals.items():
self.stimulators[name].update(signal)

def set_io_processor(self, io_processor, communication_period=None) -> None:
def set_io_processor(self, io_processor, communication_period=None) -> CLSimulator:
"""Set simulator IO processor
Will replace any previous IOProcessor so there is only one at a time.
Expand All @@ -406,6 +346,11 @@ def set_io_processor(self, io_processor, communication_period=None) -> None:
Parameters
----------
io_processor : IOProcessor
Returns
-------
CLSimulator
self
"""
self.io_processor = io_processor
# remove previous NetworkOperation
Expand All @@ -431,6 +376,7 @@ def communicate_with_io_proc(t):
)
self.network.add(self._processing_net_op)
self.network.store(self._net_store_name)
return self

def run(self, duration: Quantity, **kwparams) -> None:
"""Run simulation.
Expand All @@ -451,13 +397,11 @@ def reset(self, **kwargs):
Restores the Brian Network to where it was when the
CLSimulator was last modified (last injection, IOProcessor change).
Calls reset() on stimulators, recorders, and IOProcessor.
Calls reset() on devices and IOProcessor.
"""
# kwargs passed to stimulators, recorders, and io_processor reset
self.network.restore(self._net_store_name)
for stim in self.stimulators.values():
stim.reset(**kwargs)
for rec in self.recorders.values():
rec.reset(**kwargs)
for device in self.devices:
device.reset(**kwargs)
if self.io_processor is not None:
self.io_processor.reset(**kwargs)
Loading

0 comments on commit 448792c

Please sign in to comment.