Skip to content

Commit

Permalink
add neo support & tutorial (still need to revise docs) (#35)
Browse files Browse the repository at this point in the history
accidentally did the whole PR in one commit...
  • Loading branch information
kjohnsen authored Jul 17, 2023
1 parent 448792c commit dbeca8d
Show file tree
Hide file tree
Showing 20 changed files with 161,673 additions and 160,918 deletions.
49 changes: 47 additions & 2 deletions cleo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Tuple, Iterable
import datetime

from attrs import define, field
from brian2 import (
Expand All @@ -18,6 +19,24 @@
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.artist import Artist
import neo

import cleo.utilities


class NeoExportable(ABC):
"""Mixin class for classes that can be exported to Neo objects"""

@abstractmethod
def to_neo(self) -> neo.core.BaseNeo:
"""Return a neo.core.AnalogSignal object with the device's data
Returns
-------
neo.core.BaseNeo
Neo object representing exported data
"""
pass


@define(eq=False)
Expand Down Expand Up @@ -192,7 +211,7 @@ def get_state(self) -> Any:


@define(eq=False)
class Stimulator(InterfaceDevice):
class Stimulator(InterfaceDevice, NeoExportable):
"""Device for manipulating the network"""

value: Any = field(init=False, default=None)
Expand Down Expand Up @@ -238,9 +257,16 @@ def reset(self, **kwargs) -> None:
self.value = self.default_value
self._init_saved_vars()

def to_neo(self):
signal = cleo.utilities.analog_signal(self.t_ms, self.values, "dimensionless")
signal.name = self.name
signal.description = "Exported from Cleo stimulator device"
signal.annotate(export_datetime=datetime.datetime.now())
return signal


@define(eq=False)
class CLSimulator:
class CLSimulator(NeoExportable):
"""The centerpiece of cleo. Integrates simulation components and runs."""

network: Network = field(repr=False)
Expand Down Expand Up @@ -405,3 +431,22 @@ def reset(self, **kwargs):
device.reset(**kwargs)
if self.io_processor is not None:
self.io_processor.reset(**kwargs)

def to_neo(self) -> neo.core.Block:
block = neo.Block(
description="Exported from Cleo simulation",
)
block.annotate(export_datetime=datetime.datetime.now())
seg = neo.Segment()
block.segments.append(seg)
for device in self.devices:
if not isinstance(device, NeoExportable):
continue
dev_neo = device.to_neo()
if isinstance(dev_neo, neo.core.Group):
data_objects = dev_neo.data_children_recur
block.groups.append(dev_neo)
elif isinstance(dev_neo, neo.core.dataobject.DataObject):
data_objects = [dev_neo]
cleo.utilities.add_to_neo_segment(seg, *data_objects)
return block
30 changes: 28 additions & 2 deletions cleo/ephys/lfp.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
"""Contains LFP signals"""
from __future__ import annotations
from typing import Any
from datetime import datetime

from attrs import define, field
from brian2 import NeuronGroup, mm, ms
from brian2.monitors.spikemonitor import SpikeMonitor
import numpy as np
from nptyping import NDArray
from tklfp import TKLFP
import quantities as pq

from cleo.base import NeoExportable
from cleo.ephys.probes import Signal, Probe
import cleo.utilities


@define(eq=False)
class TKLFPSignal(Signal):
class TKLFPSignal(Signal, NeoExportable):
"""Records the Teleńczuk kernel LFP approximation.
Requires ``tklfp_type='exc'|'inh'`` to specify cell type
Expand All @@ -36,7 +40,7 @@ class TKLFPSignal(Signal):
to be considered, by default 1e-3.
This determines the buffer length of past spikes, since the uLFP from a long-past
spike becomes negligible and is ignored."""
save_history: bool = False
save_history: bool = True
"""Whether to record output from every timestep in :attr:`lfp_uV`.
Output is stored every time :meth:`get_state` is called."""
t_ms: NDArray[(Any,), float] = field(init=False, repr=False)
Expand Down Expand Up @@ -169,3 +173,25 @@ def _get_buffer_length(self, tklfp, **kwparams):
return np.ceil(
tklfp.compute_min_window_ms(self.uLFP_threshold_uV) / sample_period_ms
).astype(int)

def to_neo(self) -> neo.AnalogSignal:
# inherit docstring
try:
signal = cleo.utilities.analog_signal(
self.t_ms,
self.lfp_uV,
"uV",
)
except AttributeError:
return
signal.name = self.probe.name + "." + self.name
signal.description = f"Exported from Cleo {self.__class__.__name__} object"
signal.annotate(export_datetime=datetime.now())
# broadcast in case of uniform direction
signal.array_annotate(
x=self.probe.coords[..., 0] / mm * pq.mm,
y=self.probe.coords[..., 1] / mm * pq.mm,
z=self.probe.coords[..., 2] / mm * pq.mm,
i_channel=np.arange(self.probe.n),
)
return signal
15 changes: 13 additions & 2 deletions cleo/ephys/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from mpl_toolkits.mplot3d.axes3d import Axes3D
from matplotlib.artist import Artist
from brian2 import NeuronGroup, mm, Unit, Quantity, umeter, np
import neo

from cleo.base import Recorder
from cleo.base import Recorder, NeoExportable
from cleo.utilities import get_orth_vectors_for_V


Expand Down Expand Up @@ -78,7 +79,7 @@ def reset(self, **kwargs) -> None:


@define(eq=False)
class Probe(Recorder):
class Probe(Recorder, NeoExportable):
"""Picks up specified signals across an array of electrodes.
Visualization kwargs
Expand Down Expand Up @@ -226,6 +227,16 @@ def reset(self, **kwargs):
for signal in self.signals:
signal.reset()

def to_neo(self) -> neo.core.Group:
group = neo.core.Group(
name=self.name, description="Exported from Cleo Probe device"
)
for sig in self.signals:
if not isinstance(sig, NeoExportable):
continue
group.add(sig.to_neo())
return group


def concat_coords(*coords: Quantity) -> Quantity:
"""Combine multiple coordinate Quantity arrays into one
Expand Down
47 changes: 38 additions & 9 deletions cleo/ephys/spiking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,34 @@
from __future__ import annotations
from abc import abstractmethod
from typing import Any, Tuple
from datetime import datetime

from attrs import define, field, fields
from bidict import bidict
from brian2 import NeuronGroup, Quantity, SpikeMonitor, meter, ms
from brian2 import NeuronGroup, Quantity, SpikeMonitor, meter, ms, mm
import numpy as np

# import numpy.typing as npt
import neo
import quantities as pq
from nptyping import NDArray

from cleo.base import NeoExportable
from cleo.ephys.probes import Signal


@define(eq=False)
class Spiking(Signal):
class Spiking(Signal, NeoExportable):
"""Base class for probabilistically detecting spikes"""

perfect_detection_radius: Quantity
r_perfect_detection: Quantity
"""Radius (with Brian unit) within which all spikes
are detected"""
half_detection_radius: Quantity
r_half_detection: Quantity
"""Radius (with Brian unit) within which half of all spikes
are detected"""
cutoff_probability: float = 0.01
"""Spike detection probability below which neurons will not be
considered. For computational efficiency."""
save_history: bool = False
save_history: bool = True
"""Determines whether :attr:`t_ms`, :attr:`i`, and :attr:`t_samp_ms` are recorded"""
t_ms: NDArray[(Any,), float] = field(
init=False, factory=lambda: np.array([], dtype=float), repr=False
Expand Down Expand Up @@ -136,8 +138,8 @@ def get_state(

def _detection_prob_for_distance(self, r: Quantity) -> float:
# p(d) = h/(r-c)
a = self.perfect_detection_radius
b = self.half_detection_radius
a = self.r_perfect_detection
b = self.r_half_detection
# solving for p(a) = 1 and p(b) = .5 yields:
c = 2 * a - b
h = b - a
Expand Down Expand Up @@ -174,6 +176,21 @@ def reset(self, **kwargs) -> None:
self._mon_spikes_already_seen[j] = mon.num_spikes
self._init_saved_vars()

def to_neo(self) -> neo.Group:
group = neo.Group(allowed_types=[neo.SpikeTrain])
for i in set(self.i):
st = neo.SpikeTrain(
times=self.t_ms[self.i == i] * pq.ms,
t_stop=self.probe.sim.network.t / ms * pq.ms,
)
st.annotate(i=int(i))
group.add(st)

group.annotate(export_datetime=datetime.now())
group.name = f"{self.probe.name}.{self.name}"
group.description = f"Exported from Cleo {self.__class__.__name__} object"
return group


@define(eq=False)
class MultiUnitSpiking(Spiking):
Expand Down Expand Up @@ -218,6 +235,18 @@ def _noisily_detect_spikes_per_channel(
t_detected = t[i_spike_detected]
return i_c_detected, t_detected, y

def to_neo(self) -> neo.Group:
group = super().to_neo()
for st in group.spiketrains:
i = int(st.annotations["i"])
st.annotate(
i_channel=i,
x_contact=self.probe.coords[i, 0] / mm * pq.mm,
y_contact=self.probe.coords[i, 1] / mm * pq.mm,
z_contact=self.probe.coords[i, 2] / mm * pq.mm,
)
return group


@define(eq=False)
class SortedSpiking(Spiking):
Expand Down
Loading

0 comments on commit dbeca8d

Please sign in to comment.