Skip to content

Commit

Permalink
Misc improvements (#41)
Browse files Browse the repository at this point in the history
* improve plot function

* add Köhler illumination

* fix numpy warning product->prod

* add eNpHR3.0, 3-state Cl- pump model

* update tklfp

* give warning when action spectrum unspecified

* remove old _tutorials.rst file

* start newman15 validation notebook

* clean up imports

* support .-() in opsin names

* remove sim from device repr

* refactor IOProc to use preprocess_ctrl_signals

* add light.color property

* progress

* add notebook results to .gitignore

* tweak enphr3

* finish newman 15 validation notebook

* tweak scope viz

* bump to v0.13.0
  • Loading branch information
kjohnsen authored Feb 1, 2024
1 parent 865cb98 commit fe2cd33
Show file tree
Hide file tree
Showing 27 changed files with 2,084 additions and 336 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tmp/*
**/tmp/*
# figures from notebooks not saved by default
notebooks/img/*
notebooks/results/*

# Byte-compiled / optimized / DLL files
__pycache__/
Expand Down
16 changes: 8 additions & 8 deletions cleo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
"""Contains core classes and functions for the Cleo package."""
from __future__ import annotations

import cleo.coords

# auto-import submodules
import cleo.ephys
import cleo.imaging
import cleo.ioproc
import cleo.opto
import cleo.coords
import cleo.stimulators
import cleo.recorders
import cleo.ioproc
import cleo.registry
import cleo.stimulators
import cleo.utilities
import cleo.viz
import cleo.imaging
import cleo.registry

from cleo.base import (
CLSimulator,
Recorder,
Stimulator,
InterfaceDevice,
IOProcessor,
Recorder,
Stimulator,
SynapseDevice,
)
109 changes: 77 additions & 32 deletions cleo/base.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
"""Contains definitions for essential, base classes."""

from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Tuple, Iterable

import datetime
from abc import ABC, abstractmethod
from typing import Any, Tuple

from attrs import define, field, asdict, fields_dict
import neo
from attrs import asdict, define, field, fields_dict
from brian2 import (
np,
NeuronGroup,
BrianObjectException,
Equations,
Synapses,
Subgroup,
Network,
NetworkOperation,
NeuronGroup,
Quantity,
Subgroup,
Synapses,
Unit,
defaultclock,
ms,
Unit,
Quantity,
BrianObjectException,
np,
)
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.artist import Artist
import neo
from cleo.registry import registry_for_sim
from mpl_toolkits.mplot3d import Axes3D

import cleo.utilities
from cleo.registry import registry_for_sim
from cleo.utilities import add_to_neo_segment, analog_signal, brian_safe_name


class NeoExportable(ABC):
Expand Down Expand Up @@ -54,7 +54,7 @@ class InterfaceDevice(ABC):
other functions so that those objects can be automatically added
to the network when the device is injected.
"""
sim: CLSimulator = field(init=False, default=None)
sim: CLSimulator = field(init=False, default=None, repr=False)
"""The simulator the device is injected into """
name: str = field(kw_only=True)
"""Identifier for device, used in sampling, plotting, etc.
Expand Down Expand Up @@ -149,6 +149,7 @@ def update_artists(self, artists: list[Artist], *args, **kwargs) -> list[Artist]
return []


@define
class IOProcessor(ABC):
"""Abstract class for implementing sampling, signal processing and control
Expand All @@ -157,9 +158,12 @@ class IOProcessor(ABC):
class more useful, since delay handling is already defined.
"""

sample_period_ms: float
sample_period_ms: float = 1
"""Determines how frequently the processor takes samples"""

latest_ctrl_signal: dict = field(factory=dict, init=False, repr=False)
"""The most recent control signal returned by :meth:`get_ctrl_signals`"""

@abstractmethod
def is_sampling_now(self, time) -> bool:
"""Determines whether the processor will take a sample at this timestep.
Expand Down Expand Up @@ -191,7 +195,7 @@ def put_state(self, state_dict: dict, sample_time_ms: float) -> None:
pass

@abstractmethod
def get_ctrl_signal(self, query_time_ms: float) -> dict:
def get_ctrl_signals(self, query_time_ms: float) -> dict:
"""Get per-stimulator control signal from the :class:`~cleo.IOProcessor`.
Parameters
Expand All @@ -202,10 +206,51 @@ def get_ctrl_signal(self, query_time_ms: float) -> dict:
Returns
-------
dict
A {'stimulator_name': value} dictionary for updating stimulators.
A {'stimulator_name': ctrl_signal} dictionary for updating stimulators.
"""
pass

def get_stim_values(self, query_time_ms: float) -> dict:
ctrl_signals = self.get_ctrl_signals(query_time_ms)
self.latest_ctrl_signal.update(ctrl_signals)
stim_value_conversions = self.preprocess_ctrl_signals(
self.latest_ctrl_signal, query_time_ms
)
return ctrl_signals | stim_value_conversions

def preprocess_ctrl_signals(
self, latest_ctrl_signals: dict, query_time_ms: float
) -> dict:
"""Preprocess control signals as needed to control stimulator waveforms between samples.
I.e., if a control signal defines the frequency of a periodic light stimulus, this
function computes the current intensity given the latest frequency and the current
time. This is called immediately after :meth:`get_ctrl_signals` and on every timestep
to update the stimulator waveform between samples.
This only needs to be implemented when a stimulus that varies between samples is desired.
Otherwise, the control signal returned by :meth:`get_ctrl_signals` is used directly.
If not all stimulators need this functionality, only return a dict for those that do.
The original, unprocessed control signal is used for the others.
Parameters
----------
query_time_ms : float
Current simulation time.
Returns
-------
dict
A {'stimulator_name': value} dictionary for updating stimulators.
"""
return {}

def get_intersample_ctrl_signal(self, query_time_ms: float) -> dict:
"""Get per-stimulator control signal between samples. I.e., for implementing
a time-varying waveform based on parameters from the last sample.
Such parameters will need to be stored in the :class:`~cleo.IOProcessor`."""
return {}

def reset(self, **kwargs) -> None:
pass

Expand Down Expand Up @@ -271,7 +316,7 @@ def reset(self, **kwargs) -> None:
self._init_saved_vars()

def to_neo(self):
signal = cleo.utilities.analog_signal(self.t_ms, self.values, "dimensionless")
signal = analog_signal(self.t_ms, self.values, "dimensionless")
signal.name = self.name
signal.description = "Exported from Cleo stimulator device"
signal.annotate(export_datetime=datetime.datetime.now())
Expand Down Expand Up @@ -375,21 +420,21 @@ def get_state(self) -> dict:
state[name] = recorder.get_state()
return state

def update_stimulators(self, ctrl_signals) -> None:
def update_stimulators(self, stim_values: dict[str, Any]) -> None:
"""Update stimulators with output from the :class:`IOProcessor`
Parameters
----------
ctrl_signals : dict
{`stimulator_name`: `ctrl_signal`} dictionary with values
stim_values : dict
{`stimulator_name`: `stim_value`} dictionary with values
to update each stimulator.
"""
if ctrl_signals is None:
return
for name, signal in ctrl_signals.items():
self.stimulators[name].update(signal)
for name, value in stim_values.items():
self.stimulators[name].update(value)

def set_io_processor(self, io_processor, communication_period=None) -> CLSimulator:
def set_io_processor(
self, io_processor: IOProcessor, communication_period=None
) -> CLSimulator:
"""Set simulator IO processor
Will replace any previous IOProcessor so there is only one at a time.
Expand Down Expand Up @@ -417,8 +462,8 @@ def set_io_processor(self, io_processor, communication_period=None) -> CLSimulat
def communicate_with_io_proc(t):
if io_processor.is_sampling_now(t / ms):
io_processor.put_state(self.get_state(), t / ms)
ctrl_signal = io_processor.get_ctrl_signal(t / ms)
self.update_stimulators(ctrl_signal)
stim_values = io_processor.get_stim_values(t / ms)
self.update_stimulators(stim_values)

# communication should be at every timestep. The IOProcessor
# decides when to sample and deliver results.
Expand Down Expand Up @@ -482,7 +527,7 @@ def to_neo(self) -> neo.core.Block:
block.groups.append(dev_neo)
elif isinstance(dev_neo, neo.core.dataobject.DataObject):
data_objects = [dev_neo]
cleo.utilities.add_to_neo_segment(seg, *data_objects)
add_to_neo_segment(seg, *data_objects)
return block


Expand Down Expand Up @@ -611,7 +656,7 @@ def connect_to_neuron_group(self, neuron_group: NeuronGroup, **kwparams) -> None
model=mod_syn_model,
on_pre=self.on_pre,
namespace=mod_syn_params,
name=f"syn_{self.name}_{neuron_group.name}",
name=f"syn_{brian_safe_name(self.name)}_{neuron_group.name}",
)
syn.namespace.update(self.extra_namespace)
syn.connect(i=i_sources, j=i_targets)
Expand Down
7 changes: 3 additions & 4 deletions cleo/coords.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
"""Contains functions for assigning neuron coordinates and visualizing"""

from __future__ import annotations

from typing import Tuple

from brian2 import Quantity, mm, meter, Unit, np
from brian2 import Quantity, Unit, meter, mm, np
from brian2.groups.group import Group
from brian2.groups.neurongroup import NeuronGroup
from brian2.units.fundamentalunits import get_dimensions
import numpy as np

from cleo.utilities import (
get_orth_vectors_for_V,
modify_model_with_eqs,
uniform_cylinder_rθz,
xyz_from_rθz,
Expand Down Expand Up @@ -47,7 +46,7 @@ def assign_coords_grid_rect_prism(
ValueError
When the shape is incompatible with the number of neurons in the group
"""
num_grid_elements = np.product(shape)
num_grid_elements = np.prod(shape)
if num_grid_elements != len(neuron_group):
raise ValueError(
f"Number of elements specified in shape ({num_grid_elements}) "
Expand Down
11 changes: 6 additions & 5 deletions cleo/ephys/lfp.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
"""Contains LFP signals"""
from __future__ import annotations
from typing import Any

from datetime import datetime
from typing import Any

import numpy as np
import quantities as pq
from attrs import define, field
from brian2 import NeuronGroup, mm, ms
from brian2.monitors.spikemonitor import SpikeMonitor
import numpy as np
from nptyping import NDArray
from tklfp import TKLFP
import quantities as pq

from cleo.base import NeoExportable
from cleo.ephys.probes import Signal, Probe
import cleo.utilities
from cleo.base import NeoExportable
from cleo.ephys.probes import Signal


@define(eq=False)
Expand Down
13 changes: 6 additions & 7 deletions cleo/ephys/probes.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
"""Contains Probe and Signal classes and electrode coordinate functions"""
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterable
from operator import concat
from typing import Any, Tuple

from attrs import field, define
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib.artist import Artist
from brian2 import NeuronGroup, mm, Unit, Quantity, umeter, np
import neo
from attrs import define, field
from brian2 import NeuronGroup, Quantity, Unit, mm, np, umeter
from matplotlib.artist import Artist
from mpl_toolkits.mplot3d.axes3d import Axes3D

from cleo.base import Recorder, NeoExportable
from cleo.base import NeoExportable, Recorder
from cleo.coords import concat_coords
from cleo.utilities import get_orth_vectors_for_V

Expand Down
8 changes: 5 additions & 3 deletions cleo/imaging/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def add_self_to_plot(
y / axis_scale_unit,
z / axis_scale_unit,
color=color,
alpha=0.3,
alpha=0.2,
)

target_markers = ax.scatter(
Expand All @@ -384,15 +384,17 @@ def add_self_to_plot(
coords[:, 2] / axis_scale_unit,
marker="^",
c=color,
label=self.sensor.name,
label=f"{self.sensor.name} ROIs",
**kwargs,
)
color_rgba = target_markers.get_facecolor()
color_rgba[:, :3] = 0.3 * color_rgba[:, :3]
target_markers.set(color=color_rgba)
handles = ax.get_legend().legendHandles

handles = ax.get_legend().legendHandles
handles.append(target_markers)
patch = mpl.patches.Patch(color=color, label=self.name)
handles.append(patch)
ax.legend(handles=handles)

return [scope_marker, target_markers, plane]
Loading

0 comments on commit fe2cd33

Please sign in to comment.