From 05b3a7b203536fddfc35233d1ee4bd09c1c89f44 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:38:33 +0800 Subject: [PATCH 1/8] support right shift to call a module --- brainpy/_src/dyn/base.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index fb76ce000..6d4baf287 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -400,6 +400,20 @@ def __del__(self): def clear_input(self): pass + def __rrshift__(self, other): + """Support using right shift operator to call modules. + + Examples + -------- + + >>> import brainpy as bp + >>> x = bp.math.random.rand((10, 10)) + >>> l = bp.layers.Activation('tanh') + >>> y = x >> l + + """ + return self.__call__(other) + class DynamicalSystemNS(DynamicalSystem): """Dynamical system without the need of shared parameters passing into ``update()`` function.""" From 21afd08554b2d191071c2e00e2f764d85751c2f2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:56:12 +0800 Subject: [PATCH 2/8] fix population model bugs --- brainpy/_src/dyn/rates/populations.py | 322 +++++++++++++++++--------- 1 file changed, 214 insertions(+), 108 deletions(-) diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index 148bce097..647308d9f 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -2,13 +2,19 @@ from typing import Union, Callable -from brainpy import check, math as bm -from brainpy._src.dyn.base import NeuGroupNS as NeuGroup +from brainpy import math as bm +from brainpy import share +from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.dyn.neurons.noise_groups import OUProcess -from brainpy._src.initialize import Initializer, Uniform, parameter, variable, ZeroInit +from brainpy._src.initialize import (Initializer, + Uniform, + parameter, + variable, + variable_, + ZeroInit) from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint -from brainpy.check import is_float, is_initializer, jit_error_checking +from brainpy.check import is_initializer from brainpy.types import Shape, ArrayType __all__ = [ @@ -22,7 +28,7 @@ ] -class RateModel(NeuGroup): +class RateModel(NeuGroupNS): pass @@ -90,6 +96,7 @@ def __init__( # parameter for training mode: bm.Mode = None, + input_var: bool = True, ): super(FHN, self).__init__(size=size, name=name, @@ -113,6 +120,7 @@ def __init__( allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) # ms, timescale of the Ornstein-Uhlenbeck noise process + self.input_var = input_var # initializers is_initializer(x_initializer, 'x_initializer') @@ -121,10 +129,11 @@ def __init__( self._y_initializer = y_initializer # variables - self.x = variable(x_initializer, self.mode, self.varshape) - self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) + self.x = variable_(self._x_initializer, self.varshape, self.mode) + self.y = variable_(self._y_initializer, self.varshape, self.mode) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, self.mode) + self.input_y = variable_(bm.zeros, self.varshape, self.mode) # noise variables self.x_ou = self.y_ou = None @@ -142,13 +151,14 @@ def __init__( method=method) # integral functions - self.integral = odeint(f=JointEq([self.dx, self.dy]), method=method) + self.integral = odeint(f=JointEq(self.dx, self.dy), method=method) def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.input_y.value = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -160,25 +170,38 @@ def dx(self, x, t, y, x_ext): def dy(self, y, t, x, y_ext=0.): return (x - self.delta - self.epsilon * y) / self.tau + y_ext - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') # input - if x is not None: - self.input += x - if self.x_ou is not None: - self.input += self.x_ou() - if self.y_ou is not None: - self.input_y += self.y_ou() + if self.input_var: + if x1 is not None: + self.input += x1 + if self.x_ou is not None: + self.input += self.x_ou() + if x2 is not None: + self.input_y += x2 + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = x1 if (x1 is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = x2 if (x2 is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() # integral - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) + x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) self.x.value = x self.y.value = y + return x def clear_input(self): - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class FeedbackFHN(RateModel): @@ -268,20 +291,16 @@ def __init__( y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), method: str = 'exp_auto', name: str = None, - dt: float = None, # parameter for training mode: bm.Mode = None, + input_var: bool = True, ): super(FeedbackFHN, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) - # dt - self.dt = bm.get_dt() if dt is None else dt - is_float(self.dt, 'dt', allow_none=False, min_bound=0., allow_int=False) - # parameters self.a = parameter(a, self.varshape, allow_none=False) self.b = parameter(b, self.varshape, allow_none=False) @@ -297,6 +316,7 @@ def __init__( self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var # initializers is_initializer(x_initializer, 'x_initializer') @@ -307,9 +327,10 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.x_delay = bm.TimeDelay(self.x, self.delay, dt=self.dt, interp_method='round') - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) + self.x_delay = bm.TimeDelay(self.x, self.delay, dt=bm.dt, interp_method='round') + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None @@ -335,8 +356,9 @@ def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) self.x_delay.reset(self.x, self.delay) - self.input = variable(bm.zeros, batch_size, self.varshape) - self.input_y = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.input = variable(bm.zeros, batch_size, self.varshape) + self.input_y = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -348,30 +370,37 @@ def dx(self, x, t, y, x_ext): def dy(self, y, t, x, y_ext): return (x + self.a - self.b * y + y_ext) / self.tau - def _check_dt(self, dt): - raise ValueError(f'The "dt" {dt} used in model running is ' - f'not consistent with the "dt" {self.dt} ' - f'used in model definition.') + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') - def update(self, tdi, x=None): - t = tdi['t'] - dt = tdi['dt'] - if check.is_checking(): - jit_error_checking(not bm.isclose(dt, self.dt), self._check_dt, dt) - - if x is not None: self.input += x - if self.x_ou is not None: - self.input += self.x_ou() - if self.y_ou is not None: - self.input_y += self.y_ou() - - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) + # input + if self.input_var: + if x1 is not None: + self.input += x1 + if self.x_ou is not None: + self.input += self.x_ou() + if x2 is not None: + self.input_y += x2 + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = x1 if (x1 is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = x2 if (x2 is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x.value, self.y.value, t, x_ext=input_x, y_ext=input_y, dt=dt) self.x.value = x self.y.value = y + return x def clear_input(self): - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class QIF(RateModel): @@ -464,6 +493,7 @@ def __init__( y_initializer: Union[Initializer, Callable, ArrayType] = Uniform(0, 0.05), method: str = 'exp_auto', name: str = None, + input_var: bool = True, # parameter for training mode: bm.Mode = None, @@ -489,6 +519,7 @@ def __init__( self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var # initializers is_initializer(x_initializer, 'x_initializer') @@ -499,8 +530,9 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None @@ -523,8 +555,9 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.input_y.value = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -537,22 +570,37 @@ def dx(self, x, t, y, x_ext): return (x ** 2 + self.eta + x_ext + self.J * y * self.tau - (bm.pi * y * self.tau) ** 2) / self.tau - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') - if x is not None: self.input += x - if self.x_ou is not None: - self.input += self.x_ou() - if self.y_ou is not None: - self.input_y += self.y_ou() - - x, y = self.integral(self.x, self.y, t=t, x_ext=self.input, y_ext=self.input_y, dt=dt) + # input + if self.input_var: + if x1 is not None: + self.input += x1 + if self.x_ou is not None: + self.input += self.x_ou() + if x2 is not None: + self.input_y += x2 + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = x1 if (x1 is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = x2 if (x2 is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x, self.y, t=t, x_ext=input_x, y_ext=input_y, dt=dt) self.x.value = x self.y.value = y + return x def clear_input(self): - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class StuartLandauOscillator(RateModel): @@ -606,6 +654,7 @@ def __init__( # parameter for training mode: bm.Mode = None, + input_var: bool = True, ): super(StuartLandauOscillator, self).__init__(size=size, name=name, @@ -623,6 +672,7 @@ def __init__( self.y_ou_sigma = parameter(y_ou_sigma, self.varshape, allow_none=False) self.x_ou_tau = parameter(x_ou_tau, self.varshape, allow_none=False) self.y_ou_tau = parameter(y_ou_tau, self.varshape, allow_none=False) + self.input_var = input_var # initializers is_initializer(x_initializer, 'x_initializer') @@ -633,8 +683,9 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) + if input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None @@ -657,8 +708,9 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.input_y.value = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -670,29 +722,44 @@ def dx(self, x, t, y, x_ext, a, w): def dy(self, y, t, x, y_ext, a, w): return (a - x * x - y * y) * y - w * y + y_ext - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') - if x is not None: self.input += x - if self.x_ou is not None: - self.input += self.x_ou() - if self.y_ou is not None: - self.input_y += self.y_ou() + # input + if self.input_var: + if x1 is not None: + self.input += x1 + if self.x_ou is not None: + self.input += self.x_ou() + if x2 is not None: + self.input_y += x2 + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = x1 if (x1 is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = x2 if (x2 is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() x, y = self.integral(self.x, self.y, t=t, - x_ext=self.input, - y_ext=self.input_y, + x_ext=input_x, + y_ext=input_y, a=self.a, w=self.w, dt=dt) self.x.value = x self.y.value = y + return x def clear_input(self): - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class WilsonCowanModel(RateModel): @@ -759,8 +826,9 @@ def __init__( # parameter for training mode: bm.Mode = None, + input_var: bool = True, ): - super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size) + super(WilsonCowanModel, self).__init__(size=size, name=name, keep_size=keep_size, mode=mode) # model parameters self.E_a = parameter(E_a, self.varshape, allow_none=False) @@ -774,6 +842,7 @@ def __init__( self.wEI = parameter(wEI, self.varshape, allow_none=False) self.wII = parameter(wII, self.varshape, allow_none=False) self.r = parameter(r, self.varshape, allow_none=False) + self.input_var = input_var # noise parameters self.x_ou_mean = parameter(x_ou_mean, self.varshape, allow_none=False) @@ -792,8 +861,9 @@ def __init__( # variables self.x = variable(x_initializer, self.mode, self.varshape) self.y = variable(y_initializer, self.mode, self.varshape) - self.input = variable(bm.zeros, self.mode, self.varshape) - self.input_y = variable(bm.zeros, self.mode, self.varshape) + if self.input_var: + self.input = variable(bm.zeros, self.mode, self.varshape) + self.input_y = variable(bm.zeros, self.mode, self.varshape) # noise variables self.x_ou = self.y_ou = None @@ -816,8 +886,9 @@ def __init__( def reset_state(self, batch_size=None): self.x.value = variable(self._x_initializer, batch_size, self.varshape) self.y.value = variable(self._y_initializer, batch_size, self.varshape) - self.input.value = variable(bm.zeros, batch_size, self.varshape) - self.input_y.value = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.input.value = variable(bm.zeros, batch_size, self.varshape) + self.input_y.value = variable(bm.zeros, batch_size, self.varshape) if self.x_ou is not None: self.x_ou.reset_state(batch_size) if self.y_ou is not None: @@ -834,20 +905,37 @@ def dy(self, y, t, x, y_ext): x = self.wEI * x - self.wII * y + y_ext return (-y + (1 - self.r * y) * self.F(x, self.I_a, self.I_theta)) / self.I_tau - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] - if x is not None: self.input += x - if self.x_ou is not None: - self.input += self.x_ou() - if self.y_ou is not None: - self.input_y += self.y_ou() - x, y = self.integral(self.x, self.y, t, x_ext=self.input, y_ext=self.input_y, dt=dt) + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') + + # input + if self.input_var: + if x1 is not None: + self.input += x1 + if self.x_ou is not None: + self.input += self.x_ou() + if x2 is not None: + self.input_y += x2 + if self.y_ou is not None: + self.input_y += self.y_ou() + input_x = self.input.value + input_y = self.input_y.value + else: + input_x = x1 if (x1 is not None) else 0. + if self.x_ou is not None: input_x += self.x_ou() + input_y = x2 if (x2 is not None) else 0. + if self.y_ou is not None: input_y += self.y_ou() + + x, y = self.integral(self.x, self.y, t, x_ext=input_x, y_ext=input_y, dt=dt) self.x.value = x self.y.value = y + return x def clear_input(self): - self.input.value = bm.zeros_like(self.input) - self.input_y.value = bm.zeros_like(self.input_y) + if self.input_var: + self.input.value = bm.zeros_like(self.input) + self.input_y.value = bm.zeros_like(self.input_y) class JansenRitModel(RateModel): @@ -913,6 +1001,7 @@ def __init__( # parameter for training mode: bm.Mode = None, + input_var: bool = True, ): super(ThresholdLinearModel, self).__init__(size, name=name, @@ -929,12 +1018,14 @@ def __init__( self.noise_i = parameter(noise_i, self.varshape, False) self._e_initializer = e_initializer self._i_initializer = i_initializer + self.input_var = input_var # variables self.e = variable(e_initializer, self.mode, self.varshape) # Firing rate of excitatory population self.i = variable(i_initializer, self.mode, self.varshape) # Firing rate of inhibitory population - self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population - self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population + if self.input_var: + self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population + self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0): self.rng = bm.random.default_rng(seed) @@ -945,25 +1036,40 @@ def reset(self, batch_size=None): def reset_state(self, batch_size=None): self.e.value = variable(self._e_initializer, batch_size, self.varshape) self.i.value = variable(self._i_initializer, batch_size, self.varshape) - self.Ie.value = variable(bm.zeros, batch_size, self.varshape) - self.Ii.value = variable(bm.zeros, batch_size, self.varshape) + if self.input_var: + self.Ie.value = variable(bm.zeros, batch_size, self.varshape) + self.Ii.value = variable(bm.zeros, batch_size, self.varshape) - def update(self, tdi, x=None): - t, dt = tdi['t'], tdi['dt'] + def update(self, x1=None, x2=None): + t = share.load('t') + dt = share.load('dt') - if x is not None: self.Ie += x - de = -self.e + self.beta_e * bm.maximum(self.Ie, 0.) + # input + if self.input_var: + if x1 is not None: + self.Ie += x1 + if x2 is not None: + self.Ii += x2 + input_e = self.Ie.value + input_i = self.Ii.value + else: + input_e = x1 if (x1 is not None) else 0. + input_i = x2 if (x2 is not None) else 0. + + de = -self.e + self.beta_e * bm.maximum(input_e, 0.) if bm.any(self.noise_e != 0.): de += self.rng.randn(self.varshape) * self.noise_e de = de / self.tau_e self.e.value = bm.maximum(self.e + de * dt, 0.) - di = -self.i + self.beta_i * bm.maximum(self.Ii, 0.) + di = -self.i + self.beta_i * bm.maximum(input_i, 0.) if bm.any(self.noise_i != 0.): di += self.rng.randn(self.varshape) * self.noise_i di = di / self.tau_i self.i.value = bm.maximum(self.i + di * dt, 0.) + return self.e.value def clear_input(self): - self.Ie.value = bm.zeros_like(self.Ie) - self.Ii.value = bm.zeros_like(self.Ii) + if self.input_var: + self.Ie.value = bm.zeros_like(self.Ie) + self.Ii.value = bm.zeros_like(self.Ii) From 1a416af8ce56963119fb8176343bdd8dc8180f1c Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:56:27 +0800 Subject: [PATCH 3/8] update neuron models --- brainpy/_src/dyn/neurons/biological_models.py | 7 +- brainpy/_src/dyn/neurons/input_groups.py | 2 +- brainpy/_src/dyn/neurons/noise_groups.py | 2 +- brainpy/_src/dyn/neurons/reduced_models.py | 196 +++++++++++++++--- 4 files changed, 178 insertions(+), 29 deletions(-) diff --git a/brainpy/_src/dyn/neurons/biological_models.py b/brainpy/_src/dyn/neurons/biological_models.py index fefe9253d..ef27bec35 100644 --- a/brainpy/_src/dyn/neurons/biological_models.py +++ b/brainpy/_src/dyn/neurons/biological_models.py @@ -6,7 +6,12 @@ from brainpy import check from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.dyn.context import share -from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_ +from brainpy._src.initialize import (OneInit, + Uniform, + Initializer, + parameter, + noise as init_noise, + variable_) from brainpy._src.integrators.joint_eq import JointEq from brainpy._src.integrators.ode.generic import odeint from brainpy._src.integrators.sde.generic import sdeint diff --git a/brainpy/_src/dyn/neurons/input_groups.py b/brainpy/_src/dyn/neurons/input_groups.py index 833d2eb9f..512a0dc3e 100644 --- a/brainpy/_src/dyn/neurons/input_groups.py +++ b/brainpy/_src/dyn/neurons/input_groups.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from brainpy._src.dyn.context import share import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroupNS, not_pass_shared +from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.initialize import Initializer, parameter, variable_ from brainpy.types import Shape, ArrayType diff --git a/brainpy/_src/dyn/neurons/noise_groups.py b/brainpy/_src/dyn/neurons/noise_groups.py index 3c6c14f40..93916fab0 100644 --- a/brainpy/_src/dyn/neurons/noise_groups.py +++ b/brainpy/_src/dyn/neurons/noise_groups.py @@ -5,7 +5,7 @@ import jax.numpy as jnp from brainpy._src.dyn.context import share from brainpy import math as bm, initialize as init -from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared +from brainpy._src.dyn.base import NeuGroupNS as NeuGroup from brainpy._src.initialize import Initializer from brainpy._src.integrators.sde.generic import sdeint from brainpy.types import ArrayType, Shape diff --git a/brainpy/_src/dyn/neurons/reduced_models.py b/brainpy/_src/dyn/neurons/reduced_models.py index f17a2c30c..3cecd7aca 100644 --- a/brainpy/_src/dyn/neurons/reduced_models.py +++ b/brainpy/_src/dyn/neurons/reduced_models.py @@ -6,7 +6,7 @@ from jax.lax import stop_gradient import brainpy.math as bm -from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared +from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.dyn.context import share from brainpy._src.initialize import (ZeroInit, OneInit, @@ -33,16 +33,84 @@ ] -class LeakyIntegrator(NeuGroup): +class Leaky(NeuGroupNS): r"""Leaky Integrator Model. **Model Descriptions** - This class implements a leaky integrator model, in which its dynamics is + This class implements a leaky model, in which its dynamics is given by: .. math:: + x(t + \Delta t) = \exp{-1/\tau \Delta t} x(t) + I + + Parameters + ---------- + size: sequence of int, int + The size of the neuron group. + tau: float, ArrayType, Initializer, callable + Membrane time constant. + method: str + The numerical integration method. + name: str + The group name. + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + tau: Union[float, ArrayType, Initializer, Callable] = 10., + name: str = None, + mode: bm.Mode = None, + method: str = 'exp_auto', + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name) + assert self.mode.is_parent_of(bm.TrainingMode, bm.NonBatchingMode) + + # parameters + self.tau = parameter(tau, self.varshape, allow_none=False) + + # integral + self.integral = odeint(method=method, f=self.derivative) + + # variables + self.reset_state(self.mode) + + def derivative(self, x, t): + return -x / self.tau + + def reset_state(self, batch_size=None): + self.x = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + r = self.integral(self.x.value, t, dt) + if x is not None: + r += x + self.x.value = r + return r + + def clear_input(self): + if self.input_var: + self.input[:] = 0. + + +class LeakyIntegrator(NeuGroupNS): + r"""Leaky Integrator Model. + + **Model Descriptions** + + This class implements a leaky integrator model, in which its dynamics is + given by: + + .. math:: + \tau \frac{dV}{dt} = - (V(t) - V_{rest}) + RI(t) where :math:`V` is the membrane potential, :math:`V_{rest}` is the resting @@ -122,7 +190,6 @@ def reset_state(self, batch_size=None): if self.input_var: self.input = variable_(bm.zeros, self.varshape, batch_size) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -140,7 +207,95 @@ def clear_input(self): self.input[:] = 0. -class LIF(NeuGroup): +class Integrator(NeuGroupNS): + r"""Integrator Model. + + This class implements an integrator model, in which its dynamics is + given by: + + .. math:: + + \tau \frac{dx}{dt} = - x(t) + I(t) + + where :math:`x` is the integrator value, and :math:`\tau` is the time constant. + + Parameters + ---------- + size: sequence of int, int + The size of the neuron group. + tau: float, ArrayType, Initializer, callable + Membrane time constant. + x_initializer: ArrayType, Initializer, callable + The initializer of :math:`x`. + noise: ArrayType, Initializer, callable + The noise added onto the membrane potential + method: str + The numerical integration method. + name: str + The group name. + """ + + def __init__( + self, + size: Shape, + keep_size: bool = False, + tau: Union[float, ArrayType, Initializer, Callable] = 10., + x_initializer: Union[Initializer, Callable, ArrayType] = ZeroInit(), + noise: Union[float, ArrayType, Initializer, Callable] = None, + input_var: bool = False, + name: str = None, + mode: bm.Mode = None, + method: str = 'exp_auto', + ): + super().__init__(size=size, + mode=mode, + keep_size=keep_size, + name=name) + is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode)) + + # parameters + self.tau = parameter(tau, self.varshape, allow_none=False) + self.noise = init_noise(noise, self.varshape) + self.input_var = input_var + + # initializers + self._x_initializer = is_initializer(x_initializer) + + # integral + if self.noise is None: + self.integral = odeint(method=method, f=self.derivative) + else: + self.integral = sdeint(method=method, f=self.derivative, g=self.noise) + + # variables + self.reset_state(self.mode) + + def derivative(self, V, t, I_ext): + return (-V + I_ext) / self.tau + + def reset_state(self, batch_size=None): + self.x = variable_(self._x_initializer, self.varshape, batch_size) + if self.input_var: + self.input = variable_(bm.zeros, self.varshape, batch_size) + + def update(self, x=None): + t = share.load('t') + dt = share.load('dt') + if self.input_var: + if x is not None: + self.input += x + x = self.input.value + else: + x = 0. if x is None else x + self.x.value = self.integral(self.x.value, t, I_ext=x, dt=dt) + return self.x.value + + def clear_input(self): + if self.input_var: + self.input[:] = 0. + + +class LIF(NeuGroupNS): r"""Leaky integrate-and-fire neuron model. **Model Descriptions** @@ -266,7 +421,6 @@ def reset_state(self, batch_size=None): if self.ref_var: self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -325,7 +479,7 @@ def clear_input(self): self.input[:] = 0. -class ExpIF(NeuGroup): +class ExpIF(NeuGroupNS): r"""Exponential integrate-and-fire neuron model. **Model Descriptions** @@ -491,7 +645,6 @@ def derivative(self, V, t, I_ext): dvdt = (- (V - self.V_rest) + exp_v + self.R * I_ext) / self.tau return dvdt - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -522,7 +675,7 @@ def clear_input(self): self.input[:] = 0. -class AdExIF(NeuGroup): +class AdExIF(NeuGroupNS): r"""Adaptive exponential integrate-and-fire neuron model. **Model Descriptions** @@ -678,7 +831,6 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -706,7 +858,7 @@ def clear_input(self): self.input[:] = 0. -class QuaIF(NeuGroup): +class QuaIF(NeuGroupNS): r"""Quadratic Integrate-and-Fire neuron model. **Model Descriptions** @@ -837,7 +989,6 @@ def derivative(self, V, t, I_ext): dVdt = (self.c * (V - self.V_rest) * (V - self.V_c) + self.R * I_ext) / self.tau return dVdt - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -867,7 +1018,7 @@ def clear_input(self): self.input[:] = 0. -class AdQuaIF(NeuGroup): +class AdQuaIF(NeuGroupNS): r"""Adaptive quadratic integrate-and-fire neuron model. **Model Descriptions** @@ -1018,7 +1169,6 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1040,7 +1190,7 @@ def clear_input(self): self.input[:] = 0. -class GIF(NeuGroup): +class GIF(NeuGroupNS): r"""Generalized Integrate-and-Fire model. **Model Descriptions** @@ -1220,7 +1370,6 @@ def dV(self, V, t, I1, I2, I_ext): def derivative(self): return JointEq([self.dI1, self.dI2, self.dVth, self.dV]) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1258,7 +1407,7 @@ def clear_input(self): self.input[:] = 0. -class ALIFBellec2020(NeuGroup): +class ALIFBellec2020(NeuGroupNS): r"""Leaky Integrate-and-Fire model with SFA [1]_. This model is similar to the GLIF2 model in the Technical White Paper @@ -1371,7 +1520,6 @@ def reset_state(self, batch_size=None): self.t_last_spike = variable_(lambda s: bm.ones(s) * -1e7, self.varshape, batch_size) self.refractory = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, batch_size) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1423,7 +1571,7 @@ def clear_input(self): self.input[:] = 0. -class Izhikevich(NeuGroup): +class Izhikevich(NeuGroupNS): r"""The Izhikevich neuron model. **Model Descriptions** @@ -1565,7 +1713,6 @@ def du(self, u, t, V): dudt = self.a * (self.b * V - u) return dudt - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1623,7 +1770,7 @@ def clear_input(self): self.input[:] = 0. -class HindmarshRose(NeuGroup): +class HindmarshRose(NeuGroupNS): r"""Hindmarsh-Rose neuron model. **Model Descriptions** @@ -1805,7 +1952,6 @@ def dz(self, z, t, V): def derivative(self): return JointEq([self.dV, self.dy, self.dz]) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1830,7 +1976,7 @@ def clear_input(self): self.input[:] = 0. -class FHN(NeuGroup): +class FHN(NeuGroupNS): r"""FitzHugh-Nagumo neuron model. **Model Descriptions** @@ -1978,7 +2124,6 @@ def dw(self, w, t, V): def derivative(self): return JointEq([self.dV, self.dw]) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') @@ -1999,7 +2144,7 @@ def clear_input(self): self.input[:] = 0. -class LIF_SFA_Bellec2020(NeuGroup): +class LIF_SFA_Bellec2020(NeuGroupNS): r"""Leaky Integrate-and-Fire model with SFA [1]_. This model is similar to the GLIF2 model in the Technical White Paper @@ -2096,7 +2241,6 @@ def reset_state(self, batch_size=None): if self.tau_ref is not None: self.t_last_spike = variable_(OneInit(-1e7), self.varshape, batch_size) - @not_pass_shared def update(self, x=None): t = share.load('t') dt = share.load('dt') From 9f914e85933f7c8b5c04f9b24ec761cc245261bf Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:57:22 +0800 Subject: [PATCH 4/8] update models --- brainpy/_src/dyn/layers/normalization.py | 28 +-- brainpy/_src/dyn/runners.py | 2 +- .../_src/dyn/synapses_v2/abstract_synapses.py | 161 +++++++++++++++++- brainpy/_src/dyn/synapses_v2/syn_outs.py | 8 +- .../_src/dyn/synapses_v2/syn_plasticity.py | 6 +- brainpy/math/activations.py | 1 + brainpy/neurons.py | 2 + 7 files changed, 180 insertions(+), 28 deletions(-) diff --git a/brainpy/_src/dyn/layers/normalization.py b/brainpy/_src/dyn/layers/normalization.py index 79811ff93..52d0303c5 100644 --- a/brainpy/_src/dyn/layers/normalization.py +++ b/brainpy/_src/dyn/layers/normalization.py @@ -130,16 +130,18 @@ def update(self, x): x = bm.as_jax(x) if share.load('fit'): - mean = jnp.mean(x, self.axis) - mean_of_square = jnp.mean(_square(x), self.axis) - if self.axis_name is not None: - mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]), - axis_name=self.axis_name, - axis_index_groups=self.axis_index_groups), - 2) - var = jnp.maximum(0., mean_of_square - _square(mean)) - self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) - self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) + mean = jnp.mean(x, self.axis) + mean_of_square = jnp.mean(_square(x), self.axis) + if self.axis_name is not None: + mean, mean_of_square = jnp.split( + lax.pmean(jnp.concatenate([mean, mean_of_square]), + axis_name=self.axis_name, + axis_index_groups=self.axis_index_groups), + 2 + ) + var = jnp.maximum(0., mean_of_square - _square(mean)) + self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean) + self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var) else: mean = self.running_mean.value var = self.running_var.value @@ -488,7 +490,7 @@ def __init__( self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape)) self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape)) - def update(self,x): + def update(self, x): if x.shape[-len(self.normalized_shape):] != self.normalized_shape: raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), ' f'but we got {x.shape}') @@ -629,6 +631,8 @@ def __init__( scale_initializer=scale_initializer, mode=mode, name=name) + + BatchNorm1D = BatchNorm1d BatchNorm2D = BatchNorm2d -BatchNorm3D = BatchNorm3d \ No newline at end of file +BatchNorm3D = BatchNorm3d diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index 30fe76dc8..c28d20d2f 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -597,7 +597,7 @@ def _get_input_time_step(self, duration=None, xs=None) -> int: if duration is not None: return int(duration / self.dt) if xs is not None: - if isinstance(xs, (bm.Array, jnp.ndarray)): + if isinstance(xs, (bm.Array, jax.Array, np.ndarray)): return xs.shape[0] if self.data_first_axis == 'T' else xs.shape[1] else: leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.Array)) diff --git a/brainpy/_src/dyn/synapses_v2/abstract_synapses.py b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py index a28efa1a4..8a9710043 100644 --- a/brainpy/_src/dyn/synapses_v2/abstract_synapses.py +++ b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py @@ -8,14 +8,14 @@ from brainpy._src import tools from brainpy._src.connect import TwoEndConnector, All2All, One2One from brainpy._src.dyn.context import share -from brainpy._src.dyn.synapses_v2.base import SynConn, SynOut, SynSTP +from brainpy._src.dyn.synapses_v2.base import SynConnNS, SynOutNS, SynSTPNS from brainpy._src.initialize import Initializer, variable_ -from brainpy._src.integrators import odeint +from brainpy._src.integrators import odeint, JointEq from brainpy.check import is_float from brainpy.types import ArrayType -class Exponential(SynConn): +class Exponential(SynConnNS): r"""Exponential decay synapse model. **Model Descriptions** @@ -72,8 +72,8 @@ class Exponential(SynConn): def __init__( self, conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], - out: Optional[SynOut] = None, - stp: Optional[SynSTP] = None, + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, comp_method: str = 'sparse', g_max: Union[float, ArrayType, Initializer, Callable] = 1., tau: Union[float, ArrayType] = 8.0, @@ -111,7 +111,7 @@ def update(self, pre_spike, post_v=None): if self.stp is not None: syn_value = self.stp(pre_spike) * pre_spike else: - syn_value = bm.asarray(pre_spike, dtype=bm.float_) + syn_value = pre_spike # post values if isinstance(self.conn, All2All): @@ -131,7 +131,6 @@ def update(self, pre_spike, post_v=None): transpose=True) if isinstance(self.mode, bm.BatchingMode): f = vmap(f) - post_vs = f(pre_spike) else: f = lambda s: bl.sparse_ops.cusparse_csr_matvec(self.g_max, self.conn_mask[0], @@ -141,7 +140,7 @@ def update(self, pre_spike, post_v=None): transpose=True) if isinstance(self.mode, bm.BatchingMode): f = vmap(f) - post_vs = f(syn_value) + post_vs = f(pre_spike) else: post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) @@ -153,3 +152,149 @@ def update(self, pre_spike, post_v=None): return self.out(self.g.value, post_v) else: return self.g.value + + +class DualExponential(SynConnNS): + r"""Dual exponential synapse model. + + **Model Descriptions** + + The dual exponential synapse model [1]_, also named as *difference of two exponentials* model, + is given by: + + .. math:: + + g_{\mathrm{syn}}(t)=g_{\mathrm{max}} \frac{\tau_{1} \tau_{2}}{ + \tau_{1}-\tau_{2}}\left(\exp \left(-\frac{t-t_{0}}{\tau_{1}}\right) + -\exp \left(-\frac{t-t_{0}}{\tau_{2}}\right)\right) + + where :math:`\tau_1` is the time constant of the decay phase, :math:`\tau_2` + is the time constant of the rise phase, :math:`t_0` is the time of the pre-synaptic + spike, :math:`g_{\mathrm{max}}` is the maximal conductance. + + However, in practice, this formula is hard to implement. The equivalent solution is + two coupled linear differential equations [2]_: + + .. math:: + + \begin{aligned} + &g_{\mathrm{syn}}(t)=g_{\mathrm{max}} g * \mathrm{STP} \\ + &\frac{d g}{d t}=-\frac{g}{\tau_{\mathrm{decay}}}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau_{\text {rise }}}+ \delta\left(t_{0}-t\right), + \end{aligned} + + where :math:`\mathrm{STP}` is used to model the short-term plasticity effect of synapses. + + Parameters + ---------- + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + tau_decay: float, ArrayArray, ndarray + The time constant of the synaptic decay phase. [ms] + tau_rise: float, ArrayArray, ndarray + The time constant of the synaptic rise phase. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + .. [2] Roth, A., & Van Rossum, M. C. W. (2009). Modeling Synapses. Computational + Modeling Methods for Neuroscientists. + + """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + tau_rise: Union[float, ArrayType] = 1., + method: str = 'exp_auto', + name: str = None, + mode: bm.Mode = None, + ): + super(DualExponential, self).__init__(conn=conn, + out=out, + stp=stp, + name=name, + mode=mode) + # parameters + self.comp_method = comp_method + self.tau_rise = is_float(tau_rise, allow_int=True, allow_none=False) + self.tau_decay = is_float(tau_decay, allow_int=True, allow_none=False) + + # connections and weights + self.g_max, self.conn_mask = self._init_weights(g_max, comp_method, data_if_sparse='csr') + + # function + self.integral = odeint(JointEq(self.dg, self.dh), method=method) + + # variables + self.reset_state(self.mode) + + def reset_state(self, batch_size=None): + self.h = variable_(bm.zeros, self.conn.pre_num, batch_size) + self.g = variable_(bm.zeros, self.conn.pre_num, batch_size) + if self.out is not None: + self.out.reset_state(batch_size) + if self.stp is not None: + self.stp.reset_state(batch_size) + + def dh(self, h, t): + return -h / self.tau_rise + + def dg(self, g, t, h): + return -g / self.tau_decay + h + + def update(self, pre_spike, post_v=None): + t = share.load('t') + dt = share.load('dt') + + # update synaptic variables + self.g.value, self.h.value = self.integral(self.g.value, self.h.value, t, dt=dt) + self.h += pre_spike + + # post values + syn_value = self.g.value + if self.stp is not None: + syn_value = self.stp(syn_value) + + if isinstance(self.conn, All2All): + post_vs = self._syn2post_with_all2all(syn_value, self.g_max, self.conn.include_self) + elif isinstance(self.conn, One2One): + post_vs = self._syn2post_with_one2one(syn_value, self.g_max) + else: + if self.comp_method == 'sparse': + bl = tools.import_brainpylib() + f = lambda s: bl.sparse_ops.cusparse_csr_matvec( + self.g_max, + self.conn_mask[0], + self.conn_mask[1], + s, + shape=(self.conn.pre_num, self.conn.post_num), + transpose=True + ) + if isinstance(self.mode, bm.BatchingMode): + f = vmap(f) + post_vs = f(syn_value) + else: + post_vs = self._syn2post_with_dense(syn_value, self.g_max, self.conn_mask) + + # outputs + if self.out is not None: + return self.out(post_vs, post_v) + else: + return post_vs diff --git a/brainpy/_src/dyn/synapses_v2/syn_outs.py b/brainpy/_src/dyn/synapses_v2/syn_outs.py index 435851ede..6324c82d4 100644 --- a/brainpy/_src/dyn/synapses_v2/syn_outs.py +++ b/brainpy/_src/dyn/synapses_v2/syn_outs.py @@ -2,7 +2,7 @@ from typing import Union -from brainpy._src.dyn.synapses_v2.base import SynOut +from brainpy._src.dyn.synapses_v2.base import SynOutNS from brainpy.math import exp from brainpy.types import ArrayType @@ -13,7 +13,7 @@ ] -class COBA(SynOut): +class COBA(SynOutNS): r"""Conductance-based synaptic output. Given the synaptic conductance, the model output the post-synaptic current with @@ -42,7 +42,7 @@ def update(self, post_g, post_v): return post_g * (self.E - post_v) -class CUBA(SynOut): +class CUBA(SynOutNS): r"""Current-based synaptic output. Given the conductance, this model outputs the post-synaptic current with a identity function: @@ -69,7 +69,7 @@ def update(self, g, post_V): return g -class MgBlock(SynOut): +class MgBlock(SynOutNS): r"""Synaptic output based on Magnesium blocking. Given the synaptic conductance, the model output the post-synaptic current with diff --git a/brainpy/_src/dyn/synapses_v2/syn_plasticity.py b/brainpy/_src/dyn/synapses_v2/syn_plasticity.py index e011cc8a1..2fb6f3619 100644 --- a/brainpy/_src/dyn/synapses_v2/syn_plasticity.py +++ b/brainpy/_src/dyn/synapses_v2/syn_plasticity.py @@ -6,7 +6,7 @@ from brainpy._src.dyn.context import share from brainpy import math as bm, tools -from brainpy._src.dyn.synapses_v2.base import SynSTP +from brainpy._src.dyn.synapses_v2.base import SynSTPNS from brainpy._src.initialize import variable_, OneInit, parameter from brainpy._src.integrators import odeint, JointEq from brainpy.types import ArrayType, Shape @@ -17,7 +17,7 @@ ] -class STD(SynSTP): +class STD(SynSTPNS): r"""Synaptic output with short-term depression. This model filters the synaptic current by the following equation: @@ -83,7 +83,7 @@ def update(self, pre_spike): return self.x.value -class STP(SynSTP): +class STP(SynSTPNS): r"""Synaptic output with short-term plasticity. This model filters the synaptic currents according to two variables: :math:`u` and :math:`x`. diff --git a/brainpy/math/activations.py b/brainpy/math/activations.py index 0096090f5..b4a1db8e9 100644 --- a/brainpy/math/activations.py +++ b/brainpy/math/activations.py @@ -25,3 +25,4 @@ selu as selu, identity as identity, ) +from .compat_numpy import tanh diff --git a/brainpy/neurons.py b/brainpy/neurons.py index fc084025e..ddc784bd4 100644 --- a/brainpy/neurons.py +++ b/brainpy/neurons.py @@ -25,6 +25,8 @@ ) from brainpy._src.dyn.neurons.reduced_models import ( + Leaky as Leaky, + Integrator as Integrator, LeakyIntegrator as LeakyIntegrator, LIF as LIF, ExpIF as ExpIF, From f1c505456e73084a1f1cee8415e3d4e587c68b2a Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:57:39 +0800 Subject: [PATCH 5/8] update Sequential repr --- brainpy/_src/tools/codes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/tools/codes.py b/brainpy/_src/tools/codes.py index adad57764..01debfb20 100644 --- a/brainpy/_src/tools/codes.py +++ b/brainpy/_src/tools/codes.py @@ -32,7 +32,7 @@ def repr_object(x): if BrainPyObject is None: from brainpy.math import BrainPyObject if isinstance(x, BrainPyObject): - return x.name + return repr(x) elif callable(x): signature = inspect.signature(x) args = [f'{k}={v.default}' for k, v in signature.parameters.items() From 7fb57deb3873859d42eaf0e2c1a2aec37e511c03 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 21:57:51 +0800 Subject: [PATCH 6/8] fix `LoopOverTime` bug --- brainpy/_src/dyn/transform.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/brainpy/_src/dyn/transform.py b/brainpy/_src/dyn/transform.py index 7b825762d..fdd93a1ed 100644 --- a/brainpy/_src/dyn/transform.py +++ b/brainpy/_src/dyn/transform.py @@ -278,8 +278,10 @@ def __call__( else: shared = tools.DotDict() - shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value - shared['i'] = jnp.arange(0, length[0]) + self.i0.value + if self.t0 is not None: + shared['t'] = jnp.arange(0, self.dt * length[0], self.dt) + self.t0.value + if self.i0 is not None: + shared['i'] = jnp.arange(0, length[0]) + self.i0.value assert not self.no_state results = bm.for_loop(functools.partial(self._run, self.shared_arg), From 3c9dd556254ead38e477268d9b9dbf61908cda52 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 23:29:56 +0800 Subject: [PATCH 7/8] update experimental synapse models --- brainpy/__init__.py | 11 +- .../_src/dyn/synapses_v2/abstract_synapses.py | 103 ++++++++++++++++++ brainpy/_src/dyn/synapses_v2/base.py | 14 +-- brainpy/_src/dyn/synapses_v2/others.py | 86 +++++++++++++++ brainpy/experimental.py | 13 ++- 5 files changed, 210 insertions(+), 17 deletions(-) create mode 100644 brainpy/_src/dyn/synapses_v2/others.py diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 9dc05d28e..783ed6bb9 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.3.6" +__version__ = "2.3.7" # fundamental supporting modules @@ -61,13 +61,11 @@ experimental, ) from brainpy._src.dyn.base import not_pass_shared -from brainpy._src.dyn.base import (DynamicalSystem, - DynamicalSystemNS, +from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem, Container as Container, Sequential as Sequential, Network as Network, NeuGroup as NeuGroup, - NeuGroupNS as NeuGroupNS, SynConn as SynConn, SynOut as SynOut, SynSTP as SynSTP, @@ -75,6 +73,11 @@ TwoEndConn as TwoEndConn, CondNeuGroup as CondNeuGroup, Channel as Channel) +from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS, + NeuGroupNS as NeuGroupNS) +from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS, + SynSTPNS as SynSTPNS, + SynConnNS as SynConnNS, ) from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,) from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner from brainpy._src.dyn.context import share, Delay diff --git a/brainpy/_src/dyn/synapses_v2/abstract_synapses.py b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py index 8a9710043..39fdaf937 100644 --- a/brainpy/_src/dyn/synapses_v2/abstract_synapses.py +++ b/brainpy/_src/dyn/synapses_v2/abstract_synapses.py @@ -298,3 +298,106 @@ def update(self, pre_spike, post_v=None): return self.out(post_vs, post_v) else: return post_vs + + +class Alpha(DualExponential): + r"""Alpha synapse model. + + **Model Descriptions** + + The analytical expression of alpha synapse is given by: + + .. math:: + + g_{syn}(t)= g_{max} \frac{t-t_{s}}{\tau} \exp \left(-\frac{t-t_{s}}{\tau}\right). + + While, this equation is hard to implement. So, let's try to convert it into the + differential forms: + + .. math:: + + \begin{aligned} + &g_{\mathrm{syn}}(t)= g_{\mathrm{max}} g \\ + &\frac{d g}{d t}=-\frac{g}{\tau}+h \\ + &\frac{d h}{d t}=-\frac{h}{\tau}+\delta\left(t_{0}-t\right) + \end{aligned} + + **Model Examples** + + .. plot:: + :include-source: True + + >>> import brainpy as bp + >>> from brainpy import neurons, synapses, synouts + >>> import matplotlib.pyplot as plt + >>> + >>> neu1 = neurons.LIF(1) + >>> neu2 = neurons.LIF(1) + >>> syn1 = synapses.Alpha(neu1, neu2, bp.connect.All2All(), output=synouts.CUBA()) + >>> net = bp.Network(pre=neu1, syn=syn1, post=neu2) + >>> + >>> runner = bp.DSRunner(net, inputs=[('pre.input', 25.)], monitors=['pre.V', 'post.V', 'syn.g', 'syn.h']) + >>> runner.run(150.) + >>> + >>> fig, gs = bp.visualize.get_figure(2, 1, 3, 8) + >>> fig.add_subplot(gs[0, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['pre.V'], label='pre-V') + >>> plt.plot(runner.mon.ts, runner.mon['post.V'], label='post-V') + >>> plt.legend() + >>> fig.add_subplot(gs[1, 0]) + >>> plt.plot(runner.mon.ts, runner.mon['syn.g'], label='g') + >>> plt.plot(runner.mon.ts, runner.mon['syn.h'], label='h') + >>> plt.legend() + >>> plt.show() + + Parameters + ---------- + conn: optional, ArrayType, dict of (str, ndarray), TwoEndConnector + The synaptic connections. + comp_method: str + The connection type used for model speed optimization. It can be + `sparse` and `dense`. The default is `sparse`. + delay_step: int, ArrayType, Initializer, Callable + The delay length. It should be the value of :math:`\mathrm{delay\_time / dt}`. + tau_decay: float, ArrayType + The time constant of the synaptic decay phase. [ms] + g_max: float, ArrayType, Initializer, Callable + The synaptic strength (the maximum conductance). Default is 1. + name: str + The name of this synaptic projection. + method: str + The numerical integration methods. + + References + ---------- + + .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw. + "The Synapse." Principles of Computational Modelling in Neuroscience. + Cambridge: Cambridge UP, 2011. 172-95. Print. + """ + + def __init__( + self, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + out: Optional[SynOutNS] = None, + stp: Optional[SynSTPNS] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 1., + tau_decay: Union[float, ArrayType] = 10.0, + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + ): + super().__init__(conn=conn, + comp_method=comp_method, + g_max=g_max, + tau_decay=tau_decay, + tau_rise=tau_decay, + method=method, + out=out, + stp=stp, + name=name, + mode=mode) + diff --git a/brainpy/_src/dyn/synapses_v2/base.py b/brainpy/_src/dyn/synapses_v2/base.py index cc1a36e8d..07ad87c98 100644 --- a/brainpy/_src/dyn/synapses_v2/base.py +++ b/brainpy/_src/dyn/synapses_v2/base.py @@ -10,12 +10,12 @@ from brainpy.types import ArrayType -class SynConn(DynamicalSystemNS): +class SynConnNS(DynamicalSystemNS): def __init__( self, conn: TwoEndConnector, - out: Optional['SynOut'] = None, - stp: Optional['SynSTP'] = None, + out: Optional['SynOutNS'] = None, + stp: Optional['SynSTPNS'] = None, name: str = None, mode: bm.Mode = None, ): @@ -28,8 +28,8 @@ def __init__( self.post_size = conn.post_size self.pre_num = conn.pre_num self.post_num = conn.post_num - assert out is None or isinstance(out, SynOut) - assert stp is None or isinstance(stp, SynSTP) + assert out is None or isinstance(out, SynOutNS) + assert stp is None or isinstance(stp, SynSTPNS) self.out = out self.stp = stp @@ -118,7 +118,7 @@ def _syn2post_with_dense(self, syn_value, syn_weight, conn_mat): return post_vs -class SynOut(DynamicalSystemNS): +class SynOutNS(DynamicalSystemNS): def update(self, post_g, post_v): raise NotImplementedError @@ -126,7 +126,7 @@ def reset_state(self, batch_size: Optional[int] = None): pass -class SynSTP(DynamicalSystemNS): +class SynSTPNS(DynamicalSystemNS): """Base class for synaptic short-term plasticity.""" def update(self, pre_spike): diff --git a/brainpy/_src/dyn/synapses_v2/others.py b/brainpy/_src/dyn/synapses_v2/others.py new file mode 100644 index 000000000..e21d9f881 --- /dev/null +++ b/brainpy/_src/dyn/synapses_v2/others.py @@ -0,0 +1,86 @@ + +from typing import Union, Optional + +import brainpy.math as bm +from brainpy._src.dyn.base import DynamicalSystemNS +from brainpy._src.dyn.context import share +from brainpy.check import is_float, is_integer + + +class PoissonInput(DynamicalSystemNS): + """Poisson Input. + + Adds independent Poisson input to a target variable. For large + numbers of inputs, this is much more efficient than creating a + `PoissonGroup`. The synaptic events are generated randomly during the + simulation and are not preloaded and stored in memory. All the inputs must + target the same variable, have the same frequency and same synaptic weight. + All neurons in the target variable receive independent realizations of + Poisson spike trains. + + Parameters + ---------- + num_input: int + The number of inputs. + freq: float + The frequency of each of the inputs. Must be a scalar. + weight: float + The synaptic weight. Must be a scalar. + """ + + def __init__( + self, + target_shape, + num_input: int, + freq: Union[int, float], + weight: Union[int, float], + seed: Optional[int] = None, + mode: bm.Mode = None, + name: str = None + ): + super(PoissonInput, self).__init__(name=name, mode=mode) + + # check data + is_integer(num_input, 'num_input', min_bound=1) + is_float(freq, 'freq', min_bound=0., allow_int=True) + is_float(weight, 'weight', allow_int=True) + assert self.mode.is_parent_of(bm.NonBatchingMode, bm.BatchingMode) + + # parameters + self.target_shape = target_shape + self.num_input = num_input + self.freq = freq + self.weight = weight + self.seed = seed + self.rng = bm.random.default_rng(seed) + + def update(self): + p = self.freq * share.dt / 1e3 + a = self.num_input * p + b = self.num_input * (1 - p) + if isinstance(share.dt, (int, float)): # dt is not in tracing + if (a > 5) and (b > 5): + inp = self.rng.normal(a, b * p, self.target_shape) + else: + inp = self.rng.binomial(self.num_input, p, self.target_shape) + + else: # dt is in tracing + inp = bm.cond((a > 5) * (b > 5), + lambda _: self.rng.normal(a, b * p, self.target_shape), + lambda _: self.rng.binomial(self.num_input, p, self.target_shape), + None, + dyn_vars=self.rng) + return inp * self.weight + + def __repr__(self): + names = self.__class__.__name__ + return f'{names}(shape={self.target_shape}, num_input={self.num_input}, freq={self.freq}, weight={self.weight})' + + def reset_state(self, batch_size=None): + pass + + def reset(self, batch_size=None): + self.rng.seed(self.seed) + self.reset_state(batch_size) + + diff --git a/brainpy/experimental.py b/brainpy/experimental.py index 8dab17552..7d182a4a2 100644 --- a/brainpy/experimental.py +++ b/brainpy/experimental.py @@ -1,9 +1,4 @@ -from brainpy._src.dyn.synapses_v2.base import ( - SynConn as SynConn, - SynOut as SynOut, - SynSTP as SynSTP, -) from brainpy._src.dyn.synapses_v2.syn_plasticity import ( STD as STD, STP as STP, @@ -13,5 +8,11 @@ COBA as COBA, ) from brainpy._src.dyn.synapses_v2.abstract_synapses import ( - Exponential as Exponential, + Exponential, + DualExponential, + Alpha, +) +from brainpy._src.dyn.synapses_v2.others import ( + PoissonInput, ) + From d2bd3053cd2ba101337f1bc297720e9b34c6898f Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 12 Mar 2023 23:34:03 +0800 Subject: [PATCH 8/8] fix bugs --- brainpy/_src/dyn/rates/populations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainpy/_src/dyn/rates/populations.py b/brainpy/_src/dyn/rates/populations.py index 647308d9f..9f32df906 100644 --- a/brainpy/_src/dyn/rates/populations.py +++ b/brainpy/_src/dyn/rates/populations.py @@ -3,7 +3,7 @@ from typing import Union, Callable from brainpy import math as bm -from brainpy import share +from brainpy._src.dyn.context import share from brainpy._src.dyn.base import NeuGroupNS from brainpy._src.dyn.neurons.noise_groups import OUProcess from brainpy._src.initialize import (Initializer,