Skip to content

Commit

Permalink
update experimental synapse models
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Mar 12, 2023
1 parent 7fb57de commit 3c9dd55
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 17 deletions.
11 changes: 7 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.6"
__version__ = "2.3.7"


# fundamental supporting modules
Expand Down Expand Up @@ -61,20 +61,23 @@
experimental,
)
from brainpy._src.dyn.base import not_pass_shared
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Container as Container,
Sequential as Sequential,
Network as Network,
NeuGroup as NeuGroup,
NeuGroupNS as NeuGroupNS,
SynConn as SynConn,
SynOut as SynOut,
SynSTP as SynSTP,
SynLTP as SynLTP,
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS)
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share, Delay
Expand Down
103 changes: 103 additions & 0 deletions brainpy/_src/dyn/synapses_v2/abstract_synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,3 +298,106 @@ def update(self, pre_spike, post_v=None):
return self.out(post_vs, post_v)
else:
return post_vs


class Alpha(DualExponential):
r"""Alpha synapse model.
**Model Descriptions**
The analytical expression of alpha synapse is given by:
.. math::
g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right).
While, this equation is hard to implement. So, let's try to convert it into the
differential forms:
.. math::
\begin{aligned}
&g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\
&\frac{d g}{d t}=-\frac{g}{\tau}+h \\
&\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right)
\end{aligned}
**Model Examples**
.. plot::
:include-source: True
>>> import brainpy as bp
>>> from brainpy import neurons, synapses, synouts
>>> import matplotlib.pyplot as plt
>>>
>>> neu1 = neurons.LIF(1)
>>> neu2 = neurons.LIF(1)
>>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA())
>>> net = bp.Network(pre=neu1, syn=syn1, post=neu2)
>>>
>>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h'])
>>> runner.run(150.)
>>>
>>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8)
>>> fig.add_subplot(gs[0, 0])
>>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V')
>>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V')
>>> plt.legend()
>>> fig.add_subplot(gs[1, 0])
>>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g')
>>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h')
>>> plt.legend()
>>> plt.show()
Parameters
----------
conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector
The synaptic connections.
comp_method: str
The connection type used for model speed optimization. It can be
`sparse` and `dense`. The default is `sparse`.
delay_step: int, ArrayType, Initializer, Callable
The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`.
tau_decay: float, ArrayType
The time constant of the synaptic decay phase. [ms]
g_max: float, ArrayType, Initializer, Callable
The synaptic strength (the maximum conductance). Default is 1.
name: str
The name of this synaptic projection.
method: str
The numerical integration methods.
References
----------
.. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
"The Synapse." Principles of Computational Modelling in Neuroscience.
Cambridge: Cambridge UP, 2011. 172-95. Print.
"""

def __init__(
self,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
out: Optional[SynOutNS] = None,
stp: Optional[SynSTPNS] = None,
comp_method: str = 'dense',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
tau_decay: Union[float, ArrayType] = 10.0,
method: str = 'exp_auto',

# other parameters
name: str = None,
mode: bm.Mode = None,
):
super().__init__(conn=conn,
comp_method=comp_method,
g_max=g_max,
tau_decay=tau_decay,
tau_rise=tau_decay,
method=method,
out=out,
stp=stp,
name=name,
mode=mode)

14 changes: 7 additions & 7 deletions brainpy/_src/dyn/synapses_v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from brainpy.types import ArrayType


class SynConn(DynamicalSystemNS):
class SynConnNS(DynamicalSystemNS):
def __init__(
self,
conn: TwoEndConnector,
out: Optional['SynOut'] = None,
stp: Optional['SynSTP'] = None,
out: Optional['SynOutNS'] = None,
stp: Optional['SynSTPNS'] = None,
name: str = None,
mode: bm.Mode = None,
):
Expand All @@ -28,8 +28,8 @@ def __init__(
self.post_size = conn.post_size
self.pre_num = conn.pre_num
self.post_num = conn.post_num
assert out is None or isinstance(out, SynOut)
assert stp is None or isinstance(stp, SynSTP)
assert out is None or isinstance(out, SynOutNS)
assert stp is None or isinstance(stp, SynSTPNS)
self.out = out
self.stp = stp

Expand Down Expand Up @@ -118,15 +118,15 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
return post_vs


class SynOut(DynamicalSystemNS):
class SynOutNS(DynamicalSystemNS):
def update(self, post_g, post_v):
raise NotImplementedError

def reset_state(self, batch_size: Optional[int] = None):
pass


class SynSTP(DynamicalSystemNS):
class SynSTPNS(DynamicalSystemNS):
"""Base class for synaptic short-term plasticity."""

def update(self, pre_spike):
Expand Down
86 changes: 86 additions & 0 deletions brainpy/_src/dyn/synapses_v2/others.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@

from typing import Union, Optional

import brainpy.math as bm
from brainpy._src.dyn.base import DynamicalSystemNS
from brainpy._src.dyn.context import share
from brainpy.check import is_float, is_integer


class PoissonInput(DynamicalSystemNS):
"""Poisson Input.
Adds independent Poisson input to a target variable. For large
numbers of inputs, this is much more efficient than creating a
`PoissonGroup`. The synaptic events are generated randomly during the
simulation and are not preloaded and stored in memory. All the inputs must
target the same variable, have the same frequency and same synaptic weight.
All neurons in the target variable receive independent realizations of
Poisson spike trains.
Parameters
----------
num_input: int
The number of inputs.
freq: float
The frequency of each of the inputs. Must be a scalar.
weight: float
The synaptic weight. Must be a scalar.
"""

def __init__(
self,
target_shape,
num_input: int,
freq: Union[int, float],
weight: Union[int, float],
seed: Optional[int] = None,
mode: bm.Mode = None,
name: str = None
):
super(PoissonInput, self).__init__(name=name, mode=mode)

# check data
is_integer(num_input, 'num_input', min_bound=1)
is_float(freq, 'freq', min_bound=0., allow_int=True)
is_float(weight, 'weight', allow_int=True)
assert self.mode.is_parent_of(bm.NonBatchingMode, bm.BatchingMode)

# parameters
self.target_shape = target_shape
self.num_input = num_input
self.freq = freq
self.weight = weight
self.seed = seed
self.rng = bm.random.default_rng(seed)

def update(self):
p = self.freq * share.dt / 1e3
a = self.num_input * p
b = self.num_input * (1 - p)
if isinstance(share.dt, (int, float)): # dt is not in tracing
if (a > 5) and (b > 5):
inp = self.rng.normal(a, b * p, self.target_shape)
else:
inp = self.rng.binomial(self.num_input, p, self.target_shape)

else: # dt is in tracing
inp = bm.cond((a > 5) * (b > 5),
lambda _: self.rng.normal(a, b * p, self.target_shape),
lambda _: self.rng.binomial(self.num_input, p, self.target_shape),
None,
dyn_vars=self.rng)
return inp * self.weight

def __repr__(self):
names = self.__class__.__name__
return f'{names}(shape={self.target_shape}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})'

def reset_state(self, batch_size=None):
pass

def reset(self, batch_size=None):
self.rng.seed(self.seed)
self.reset_state(batch_size)


13 changes: 7 additions & 6 deletions brainpy/experimental.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@

from brainpy._src.dyn.synapses_v2.base import (
SynConn as SynConn,
SynOut as SynOut,
SynSTP as SynSTP,
)
from brainpy._src.dyn.synapses_v2.syn_plasticity import (
STD as STD,
STP as STP,
Expand All @@ -13,5 +8,11 @@
COBA as COBA,
)
from brainpy._src.dyn.synapses_v2.abstract_synapses import (
Exponential as Exponential,
Exponential,
DualExponential,
Alpha,
)
from brainpy._src.dyn.synapses_v2.others import (
PoissonInput,
)

0 comments on commit 3c9dd55

Please sign in to comment.