From 9e6de4eacbcb61c8f7bf1101becaa84bd3d2e93a Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 10 Aug 2023 17:38:00 +0800 Subject: [PATCH 1/5] updates --- brainpy/__init__.py | 2 +- brainpy/_src/dyn/ions/base.py | 8 ++++---- brainpy/_src/dynsys.py | 14 +++++++++++++- brainpy/_src/integrators/ode/exponential.py | 4 ++-- brainpy/_src/math/object_transform/controls.py | 3 +-- brainpy/_src/mixin.py | 5 ++--- examples/dynamics_simulation/COBA.py | 6 +++++- examples/dynamics_simulation/COBA_parallel.py | 2 +- 8 files changed, 29 insertions(+), 15 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 1c1c12a13..121d0c6ff 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.4.3.post3" +__version__ = "2.4.3.post4" # fundamental supporting modules from brainpy import errors, check, tools diff --git a/brainpy/_src/dyn/ions/base.py b/brainpy/_src/dyn/ions/base.py index 7b3f13e29..145c1ded0 100644 --- a/brainpy/_src/dyn/ions/base.py +++ b/brainpy/_src/dyn/ions/base.py @@ -82,15 +82,15 @@ def check_hierarchy(self, roots, leaf): raise TypeError(f'Type does not match. {leaf} requires a master with type ' f'of {leaf.master_type}, but the master type now is {roots}.') - def add_elem(self, **elements): + def add_elem(self, *elems, **elements): """Add new elements. Args: elements: children objects. """ - self.check_hierarchies(self._ion_classes, **elements) - self.children.update(self.format_elements(IonChaDyn, **elements)) - for key, elem in elements.items(): + self.check_hierarchies(self._ion_classes, *elems, **elements) + self.children.update(self.format_elements(IonChaDyn, *elems, **elements)) + for elem in tuple(elems) + tuple(elements.values()): for ion_root in elem.master_type.__args__: ion = self._get_imp(ion_root) ion.add_external_current(elem.name, self._get_ion_fun(ion, elem)) diff --git a/brainpy/_src/dynsys.py b/brainpy/_src/dynsys.py index de917ca31..69d6696bd 100644 --- a/brainpy/_src/dynsys.py +++ b/brainpy/_src/dynsys.py @@ -570,8 +570,20 @@ def __repr__(self): class Projection(DynamicalSystem): def reset_state(self, *args, **kwargs): - pass + nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()) + if len(nodes): + for node in nodes: + node.reset_state(*args, **kwargs) + else: + raise ValueError('Do not implement the reset_state() function.') + def update(self, *args, **kwargs): + nodes = tuple(self.nodes(level=1, include_self=False).subset(DynamicalSystem).unique().values()) + if len(nodes): + for node in nodes: + node(*args, **kwargs) + else: + raise ValueError('Do not implement the update() function.') class Dynamic(DynamicalSystem, ReceiveInputProj): """Base class to model dynamics. diff --git a/brainpy/_src/integrators/ode/exponential.py b/brainpy/_src/integrators/ode/exponential.py index b2d142c0e..2e577e6ab 100644 --- a/brainpy/_src/integrators/ode/exponential.py +++ b/brainpy/_src/integrators/ode/exponential.py @@ -199,7 +199,7 @@ class ExponentialEuler(ODEIntegrator): >>> self.n.value = n >>> self.input[:] = 0. >>> - >>> run = bp.dyn.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) + >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) >>> run(100) >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) @@ -269,7 +269,7 @@ class ExponentialEuler(ODEIntegrator): >>> self.n.value = n >>> self.input[:] = 0. >>> - >>> run = bp.dyn.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) + >>> run = bp.DSRunner(HH(1), inputs=('input', 2.), monitors=['V'], dt=0.05) >>> run(100) >>> bp.visualize.line_plot(run.mon.ts, run.mon.V, legend='V', show=True) diff --git a/brainpy/_src/math/object_transform/controls.py b/brainpy/_src/math/object_transform/controls.py index 19efbf1af..a26c230cf 100644 --- a/brainpy/_src/math/object_transform/controls.py +++ b/brainpy/_src/math/object_transform/controls.py @@ -769,8 +769,7 @@ def for_loop( Please change your call from ``for_loop(fun, dyn_vars, operands)`` to ``for_loop(fun, operands, dyn_vars)``. - Simply speaking, all dynamically changed variables used in the body function should - be labeld in ``dyn_vars`` argument. All returns in body function will be gathered + All returns in body function will be gathered as the return of the whole loop. >>> import brainpy.math as bm diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index b206f5da6..0b4ad1ca1 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -272,7 +272,7 @@ def format_elements(self, child_type: type, *children_as_tuple, **children_as_di res[k] = v return res - def add_elem(self, **elements): + def add_elem(self, *elems, **elements): """Add new elements. >>> obj = Container() @@ -281,8 +281,7 @@ def add_elem(self, **elements): Args: elements: children objects. """ - # self.check_hierarchies(type(self), **elements) - self.children.update(self.format_elements(object, **elements)) + self.children.update(self.format_elements(object, *elems, **elements)) class TreeNode(MixIn): diff --git a/examples/dynamics_simulation/COBA.py b/examples/dynamics_simulation/COBA.py index 3517864a0..af7511e19 100644 --- a/examples/dynamics_simulation/COBA.py +++ b/examples/dynamics_simulation/COBA.py @@ -174,13 +174,17 @@ def run3(): def run4(): - bm.set(dt=0.5) + bm.set(dt=0.5, x64=True) net = EICOBA_PostAlign(3200, 800, ltc=True) runner = bp.DSRunner(net, monitors={'E.spike': net.E.spike}) print(runner.run(100., eval_time=True)) bp.visualize.raster_plot(runner.mon.ts, runner.mon['E.spike'], show=True) + + + + if __name__ == '__main__': # run1() # run2() diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py index fff6275ff..a0f10de09 100644 --- a/examples/dynamics_simulation/COBA_parallel.py +++ b/examples/dynamics_simulation/COBA_parallel.py @@ -70,8 +70,8 @@ def run(indexes): with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]): - # model = EINet1() model = EINet2() indices = bm.arange(1000) spks = run(indices) bp.visualize.raster_plot(indices, spks, show=True) + From 67890abc00a693a9786a934e41735d5e11a8d81b Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 10 Aug 2023 17:42:26 +0800 Subject: [PATCH 2/5] update type info in Projection Align --- brainpy/_src/dyn/projections/aligns.py | 47 +++++++++++++------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index 9607a6200..6b2db60de 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -6,8 +6,7 @@ from brainpy._src.delay import Delay, VarDelay, DataDelay, DelayAccess from brainpy._src.dynsys import DynamicalSystem, Projection from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo, - AutoDelaySupp, BindCondData, AlignPost, - ReceiveInputProj) + AutoDelaySupp, BindCondData, AlignPost) __all__ = [ 'VanillaProj', @@ -144,7 +143,7 @@ def __init__( self, comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -153,7 +152,7 @@ def __init__( # synaptic models check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # output initialization @@ -221,7 +220,7 @@ def __init__( comm: DynamicalSystem, syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]], out: ParamDescInit[JointType[DynamicalSystem, BindCondData]], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -231,7 +230,7 @@ def __init__( check.is_instance(comm, DynamicalSystem) check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]]) check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and output initialization @@ -330,7 +329,7 @@ def __init__( comm: DynamicalSystem, syn: ParamDescInit[JointType[DynamicalSystem, AlignPost]], out: ParamDescInit[JointType[DynamicalSystem, BindCondData]], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -341,7 +340,7 @@ def __init__( check.is_instance(comm, DynamicalSystem) check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AlignPost]]) check.is_instance(out, ParamDescInit[JointType[DynamicalSystem, BindCondData]]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # delay initialization @@ -422,7 +421,7 @@ def __init__( comm: DynamicalSystem, syn: JointType[DynamicalSystem, AlignPost], out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -432,7 +431,7 @@ def __init__( check.is_instance(comm, DynamicalSystem) check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and output initialization @@ -523,7 +522,7 @@ def __init__( comm: DynamicalSystem, syn: JointType[DynamicalSystem, AlignPost], out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -534,7 +533,7 @@ def __init__( check.is_instance(comm, DynamicalSystem) check.is_instance(syn, JointType[DynamicalSystem, AlignPost]) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm self.syn = syn @@ -634,7 +633,7 @@ def __init__( delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -645,7 +644,7 @@ def __init__( check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and delay initialization @@ -744,10 +743,10 @@ def __init__( self, pre: JointType[DynamicalSystem, AutoDelaySupp], delay: Union[None, int, float], - syn: ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]], + syn: ParamDescInit[DynamicalSystem], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -755,10 +754,10 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) - check.is_instance(syn, ParamDescInit[JointType[DynamicalSystem, AutoDelaySupp]]) + check.is_instance(syn, ParamDescInit[DynamicalSystem]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # delay initialization @@ -865,7 +864,7 @@ def __init__( delay: Union[None, int, float], comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -876,7 +875,7 @@ def __init__( check.is_instance(syn, JointType[DynamicalSystem, AutoDelaySupp]) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm # synapse and delay initialization @@ -970,10 +969,10 @@ def __init__( self, pre: JointType[DynamicalSystem, AutoDelaySupp], delay: Union[None, int, float], - syn: JointType[DynamicalSystem, AutoDelaySupp], + syn: DynamicalSystem, comm: DynamicalSystem, out: JointType[DynamicalSystem, BindCondData], - post: JointType[DynamicalSystem, ReceiveInputProj], + post: DynamicalSystem, name: Optional[str] = None, mode: Optional[bm.Mode] = None, ): @@ -981,10 +980,10 @@ def __init__( # synaptic models check.is_instance(pre, JointType[DynamicalSystem, AutoDelaySupp]) - check.is_instance(syn, JointType[DynamicalSystem, AutoDelaySupp]) + check.is_instance(syn, DynamicalSystem) check.is_instance(comm, DynamicalSystem) check.is_instance(out, JointType[DynamicalSystem, BindCondData]) - check.is_instance(post, JointType[DynamicalSystem, ReceiveInputProj]) + check.is_instance(post, DynamicalSystem) self.comm = comm self.syn = syn From 03349ad3bfee5ec093551afb1b4cb73c7fd0583f Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 10 Aug 2023 18:04:03 +0800 Subject: [PATCH 3/5] Deprecation and compatibility for the old `synapse.g_max` attribute --- brainpy/_src/dynold/synapses/base.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/brainpy/_src/dynold/synapses/base.py b/brainpy/_src/dynold/synapses/base.py index 53362219c..a6564d14d 100644 --- a/brainpy/_src/dynold/synapses/base.py +++ b/brainpy/_src/dynold/synapses/base.py @@ -1,3 +1,4 @@ +import warnings from typing import Union, Dict, Callable, Optional, Tuple import jax @@ -325,4 +326,25 @@ def update(self, pre_spike=None, stop_spike_gradient: bool = False): current = self.comm(self.syn(pre_spike)) return self.output(current) + @property + def g_max(self): + warnings.warn('".g_max" is deprecated. ' + 'Use ".comm.weight" instead.', + UserWarning) + return self.comm.weight + + @g_max.setter + def g_max(self, v): + warnings.warn('Updating ".g_max" is deprecated. ' + 'Updating ".comm.weight" instead.', + UserWarning) + self.comm.weight = v + + def reset_state(self, *args, **kwargs): + self.syn.reset_state(*args, **kwargs) + self.comm.reset_state(*args, **kwargs) + self.output.reset_state(*args, **kwargs) + if self.stp is not None: + self.stp.reset_state(*args, **kwargs) + From 6a38fdca0b8bbeea88b0bf1dff542f368a8c86ba Mon Sep 17 00:00:00 2001 From: chaoming Date: Thu, 10 Aug 2023 18:13:38 +0800 Subject: [PATCH 4/5] compatible with `brainpy.math.enable_x64(True/False)` --- brainpy/_src/math/environment.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/brainpy/_src/math/environment.py b/brainpy/_src/math/environment.py index 0f775da19..950d87933 100644 --- a/brainpy/_src/math/environment.py +++ b/brainpy/_src/math/environment.py @@ -6,6 +6,7 @@ import os import re import sys +import warnings from typing import Any, Callable, TypeVar, cast from jax import config, numpy as jnp, devices @@ -15,7 +16,6 @@ bm = None - __all__ = [ # context manage for environment setting 'environment', @@ -36,7 +36,6 @@ # default computation modes 'set_mode', 'get_mode', - # set jax environments 'enable_x64', 'disable_x64', 'set_platform', 'get_platform', @@ -53,7 +52,6 @@ ] - # See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators FuncType = Callable[..., Any] F = TypeVar('F', bound=FuncType) @@ -553,11 +551,23 @@ def get_mode() -> modes.Mode: return bm.mode -def enable_x64(): - config.update("jax_enable_x64", True) - set_int(jnp.int64) - set_float(jnp.float64) - set_complex(jnp.complex128) +def enable_x64(x64=None): + if x64 is None: + x64 = True + else: + warnings.warn( + '\n' + 'Instead of "brainpy.math.enable_x64(True)", use "brainpy.math.enable_x64()". \n' + 'Instead of "brainpy.math.enable_x64(False)", use "brainpy.math.disable_x64()". \n', + DeprecationWarning + ) + if x64: + config.update("jax_enable_x64", True) + set_int(jnp.int64) + set_float(jnp.float64) + set_complex(jnp.complex128) + else: + disable_x64() def disable_x64(): @@ -649,4 +659,3 @@ def enable_gpu_memory_preallocation(): """Disable pre-allocating the GPU memory.""" os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' os.environ.pop('XLA_PYTHON_CLIENT_ALLOCATOR') - From d113a000d11a2ebd3d167d50102c556e4fd75ea3 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 13 Aug 2023 20:37:56 +0800 Subject: [PATCH 5/5] [delay] new delay registration methods: `register_delay_at` and `get_delay_at` --- brainpy/_src/context.py | 2 +- brainpy/_src/delay.py | 42 +++++++++- brainpy/_src/dyn/projections/aligns.py | 77 ++++++------------- .../_src/dynold/synapses/abstract_models.py | 7 +- brainpy/_src/mixin.py | 74 ++++++++++++++---- brainpy/_src/tests/test_mixin.py | 20 +++++ 6 files changed, 145 insertions(+), 77 deletions(-) diff --git a/brainpy/_src/context.py b/brainpy/_src/context.py index 87724618a..6fca8a8d2 100644 --- a/brainpy/_src/context.py +++ b/brainpy/_src/context.py @@ -38,7 +38,7 @@ def set_dt(self, dt: Union[int, float]): self._arguments['dt'] = dt def load(self, key, value: Any = None): - """Get the shared data by the ``key``. + """Load the shared data by the ``key``. Args: key (str): the key to indicate the data. diff --git a/brainpy/_src/delay.py b/brainpy/_src/delay.py index c780bcd87..9b9e7bf01 100644 --- a/brainpy/_src/delay.py +++ b/brainpy/_src/delay.py @@ -16,7 +16,7 @@ from brainpy._src.dynsys import DynamicalSystem from brainpy._src.initialize import variable_ from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE -from brainpy._src.mixin import ParamDesc +from brainpy._src.mixin import ParamDesc, ReturnInfo from brainpy.check import jit_error @@ -28,6 +28,9 @@ ] +delay_identifier = '_*_delay_*_' + + class Delay(DynamicalSystem, ParamDesc): """Base class for delay variables. @@ -474,3 +477,40 @@ def update(self): return self.delay.at(self.name, *self.indices) +def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay: + if isinstance(info, bm.Variable): + return VarDelay(info) + + elif isinstance(info, ReturnInfo): + # batch size + if isinstance(info.batch_or_mode, int): + shape = (info.batch_or_mode,) + tuple(info.size) + batch_axis = 0 + elif isinstance(info.batch_or_mode, bm.NonBatchingMode): + shape = tuple(info.size) + batch_axis = None + elif isinstance(info.batch_or_mode, bm.BatchingMode): + shape = (info.batch_or_mode.batch_size,) + tuple(info.size) + batch_axis = 0 + else: + shape = tuple(info.size) + batch_axis = None + + # init + if isinstance(info.data, Callable): + init = info.data(shape) + elif isinstance(info.data, (bm.Array, jax.Array)): + init = info.data + else: + raise TypeError + assert init.shape == shape + + # axis names + if info.axis_names is not None: + assert init.ndim == len(info.axis_names) + + # variable + target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names) + return DataDelay(target, data_init=info.data) + else: + raise TypeError diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py index 6b2db60de..c53331459 100644 --- a/brainpy/_src/dyn/projections/aligns.py +++ b/brainpy/_src/dyn/projections/aligns.py @@ -3,7 +3,7 @@ import jax from brainpy import math as bm, check -from brainpy._src.delay import Delay, VarDelay, DataDelay, DelayAccess +from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return from brainpy._src.dynsys import DynamicalSystem, Projection from brainpy._src.mixin import (JointType, ParamDescInit, ReturnInfo, AutoDelaySupp, BindCondData, AlignPost) @@ -16,8 +16,6 @@ 'ProjAlignPre1', 'ProjAlignPre2', ] -_pre_delay_repr = '_*_align_pre_spk_delay_*_' - class _AlignPre(DynamicalSystem): def __init__(self, syn, delay=None): @@ -54,37 +52,6 @@ def update(self, *args, **kwargs): return self.syn(self.access()) -def _init_delay(info: Union[bm.Variable, ReturnInfo]) -> Delay: - if isinstance(info, bm.Variable): - return VarDelay(info) - elif isinstance(info, ReturnInfo): - if isinstance(info.batch_or_mode, int): - shape = (info.batch_or_mode,) + tuple(info.size) - batch_axis = 0 - elif isinstance(info.batch_or_mode, bm.NonBatchingMode): - shape = tuple(info.size) - batch_axis = None - elif isinstance(info.batch_or_mode, bm.BatchingMode): - shape = (info.batch_or_mode.batch_size,) + tuple(info.size) - batch_axis = 0 - else: - shape = tuple(info.size) - batch_axis = None - if isinstance(info.data, Callable): - init = info.data(shape) - elif isinstance(info.data, (bm.Array, jax.Array)): - init = info.data - else: - raise TypeError - assert init.shape == shape - if info.axis_names is not None: - assert init.ndim == len(info.axis_names) - target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names) - return DataDelay(target, data_init=info.data) - else: - raise TypeError - - def _get_return(return_info): if isinstance(return_info, bm.Variable): return return_info.value @@ -344,12 +311,12 @@ def __init__( self.comm = comm # delay initialization - if not pre.has_aft_update(_pre_delay_repr): + if not pre.has_aft_update(delay_identifier): # pre should support "ProjAutoDelay" - delay_cls = _init_delay(pre.return_info()) + delay_cls = init_delay_by_return(pre.return_info()) # add to "after_updates" - pre.add_aft_update(_pre_delay_repr, delay_cls) - delay_cls: Delay = pre.get_aft_update(_pre_delay_repr) + pre.add_aft_update(delay_identifier, delay_cls) + delay_cls: Delay = pre.get_aft_update(delay_identifier) delay_cls.register_entry(self.name, delay) # synapse and output initialization @@ -366,7 +333,7 @@ def __init__( self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()`` def update(self): - x = self.refs['pre'].get_aft_update(_pre_delay_repr).at(self.name) + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) current = self.comm(x) self.refs['syn'].add_current(current) # synapse post current return current @@ -538,12 +505,12 @@ def __init__( self.syn = syn # delay initialization - if not pre.has_aft_update(_pre_delay_repr): + if not pre.has_aft_update(delay_identifier): # pre should support "ProjAutoDelay" - delay_cls = _init_delay(pre.return_info()) + delay_cls = init_delay_by_return(pre.return_info()) # add to "after_updates" - pre.add_aft_update(_pre_delay_repr, delay_cls) - delay_cls: Delay = pre.get_aft_update(_pre_delay_repr) + pre.add_aft_update(delay_identifier, delay_cls) + delay_cls: Delay = pre.get_aft_update(delay_identifier) delay_cls.register_entry(self.name, delay) # synapse and output initialization @@ -554,7 +521,7 @@ def __init__( self.refs['out'] = out def update(self): - x = self.refs['pre'].get_aft_update(_pre_delay_repr).at(self.name) + x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name) g = self.syn(self.comm(x)) self.refs['out'].bind_cond(g) # synapse post current return g @@ -652,7 +619,7 @@ def __init__( if not pre.has_aft_update(self._syn_id): # "syn_cls" needs an instance of "ProjAutoDelay" syn_cls: AutoDelaySupp = syn() - delay_cls = _init_delay(syn_cls.return_info()) + delay_cls = init_delay_by_return(syn_cls.return_info()) # add to "after_updates" pre.add_aft_update(self._syn_id, _AlignPre(syn_cls, delay_cls)) delay_cls: Delay = pre.get_aft_update(self._syn_id).delay @@ -761,10 +728,10 @@ def __init__( self.comm = comm # delay initialization - if not pre.has_aft_update(_pre_delay_repr): - delay_ins = _init_delay(pre.return_info()) - pre.add_aft_update(_pre_delay_repr, delay_ins) - delay_cls = pre.get_aft_update(_pre_delay_repr) + if not pre.has_aft_update(delay_identifier): + delay_ins = init_delay_by_return(pre.return_info()) + pre.add_aft_update(delay_identifier, delay_ins) + delay_cls = pre.get_aft_update(delay_identifier) # synapse initialization self._syn_id = f'Delay({str(delay)}) // {syn.identifier}' @@ -879,7 +846,7 @@ def __init__( self.comm = comm # synapse and delay initialization - delay_cls = _init_delay(syn.return_info()) + delay_cls = init_delay_by_return(syn.return_info()) delay_cls.register_entry(self.name, delay) pre.add_aft_update(self.name, _AlignPre(syn, delay_cls)) @@ -988,10 +955,10 @@ def __init__( self.syn = syn # delay initialization - if not pre.has_aft_update(_pre_delay_repr): - delay_ins = _init_delay(pre.return_info()) - pre.add_aft_update(_pre_delay_repr, delay_ins) - delay_cls = pre.get_aft_update(_pre_delay_repr) + if not pre.has_aft_update(delay_identifier): + delay_ins = init_delay_by_return(pre.return_info()) + pre.add_aft_update(delay_identifier, delay_ins) + delay_cls = pre.get_aft_update(delay_identifier) delay_cls.register_entry(self.name, delay) # output initialization @@ -999,7 +966,7 @@ def __init__( # references self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()`` - self.refs['delay'] = pre.get_aft_update(_pre_delay_repr) + self.refs['delay'] = pre.get_aft_update(delay_identifier) def update(self): spk = self.refs['delay'].at(self.name) diff --git a/brainpy/_src/dynold/synapses/abstract_models.py b/brainpy/_src/dynold/synapses/abstract_models.py index 2f52b0be9..cddb04d7c 100644 --- a/brainpy/_src/dynold/synapses/abstract_models.py +++ b/brainpy/_src/dynold/synapses/abstract_models.py @@ -6,14 +6,11 @@ import brainpy.math as bm from brainpy._src.connect import TwoEndConnector, All2All, One2One -from brainpy._src.context import share +from brainpy._src.dnn import linear from brainpy._src.dyn import synapses from brainpy._src.dyn.base import NeuDyn -from brainpy._src.dnn import linear from brainpy._src.dynold.synouts import MgBlock, CUBA -from brainpy._src.initialize import Initializer, variable_ -from brainpy._src.integrators.ode.generic import odeint -from brainpy._src.dyn.projections.aligns import _pre_delay_repr, _init_delay +from brainpy._src.initialize import Initializer from brainpy.types import ArrayType from .base import TwoEndConn, _SynSTP, _SynOut, _TwoEndConnAlignPre diff --git a/brainpy/_src/mixin.py b/brainpy/_src/mixin.py index 0b4ad1ca1..3662812b4 100644 --- a/brainpy/_src/mixin.py +++ b/brainpy/_src/mixin.py @@ -1,5 +1,6 @@ import numbers import sys +import warnings from dataclasses import dataclass from typing import Union, Dict, Callable, Sequence, Optional, TypeVar, Any from typing import (_SpecialForm, _type_check, _remove_dups_flatten) @@ -19,6 +20,8 @@ from typing import (_GenericAlias, _tp_cache) DynamicalSystem = None +delay_identifier, init_delay_by_return = None, None + __all__ = [ 'MixIn', @@ -323,6 +326,40 @@ def check_hierarchy(self, root, leaf): class DelayRegister(MixIn): local_delay_vars: bm.node_dict + def register_delay_at( + self, + name: str, + delay: Union[numbers.Number, ArrayType] = None, + ): + """Register relay at the given delay time. + + Args: + name: str. The identifier of the delay. + delay: The delay time. + """ + global delay_identifier, init_delay_by_return, DynamicalSystem + if init_delay_by_return is None: from brainpy._src.delay import init_delay_by_return + if delay_identifier is None: from brainpy._src.delay import delay_identifier + if DynamicalSystem is None: from brainpy._src.dynsys import DynamicalSystem + + assert isinstance(self, AutoDelaySupp), f'self must be an instance of {AutoDelaySupp.__name__}' + assert isinstance(self, DynamicalSystem), f'self must be an instance of {DynamicalSystem.__name__}' + if not self.has_aft_update(delay_identifier): + self.add_aft_update(delay_identifier, init_delay_by_return(self.return_info())) + delay_cls = self.get_aft_update(delay_identifier) + delay_cls.register_entry(name, delay) + + def get_delay_at(self, name): + """Get the delay at the given identifier (`name`). + + Args: + name: The identifier of the delay. + + Returns: + The delay data. + """ + return self.get_aft_update(delay_identifier).at(name) + def register_delay( self, identifier: str, @@ -332,22 +369,22 @@ def register_delay( ): """Register delay variable. - Parameters - ---------- - identifier: str - The delay variable name. - delay_step: Optional, int, ArrayType, callable, Initializer - The number of the steps of the delay. - delay_target: Variable - The target variable for delay. - initial_delay_data: float, int, ArrayType, callable, Initializer - The initializer for the delay data. + Args: + identifier: str. The delay access name. + delay_target: The target variable for delay. + delay_step: The delay time step. + initial_delay_data: The initializer for the delay data. - Returns - ------- - delay_step: int, ArrayType - The number of the delay steps. + Returns: + delay_step: The number of the delay steps. """ + warnings.warn('\n' + 'Starting from brainpy>=2.4.4, instead of ".register_delay()", ' + 'we recommend the user to first use ".register_delay_at()", ' + 'then use ".get_delay_at()" to access the delayed data. ' + '".register_delay()" will be removed after 2.5.0.', + UserWarning) + # delay steps if delay_step is None: delay_type = 'none' @@ -422,6 +459,13 @@ def get_delay_data( delay_data: ArrayType The delay data at the given time. """ + warnings.warn('\n' + 'Starting from brainpy>=2.4.4, instead of ".get_delay_data()", ' + 'we recommend the user to first use ".register_delay_at()", ' + 'then use ".get_delay_at()" to access the delayed data.' + '".get_delay_data()" will be removed after 2.5.0.', + UserWarning) + if delay_step is None: return global_delay_data[identifier][1].value @@ -630,7 +674,7 @@ def __getitem__(self, parameters): 'JointType', doc="""Joint type; JointType[X, Y] means both X and Y. - To define a union, use e.g. JointType[int, str]. + To define a joint, use e.g. JointType[int, str]. Details: diff --git a/brainpy/_src/tests/test_mixin.py b/brainpy/_src/tests/test_mixin.py index 1544a1f33..d02e56274 100644 --- a/brainpy/_src/tests/test_mixin.py +++ b/brainpy/_src/tests/test_mixin.py @@ -1,4 +1,5 @@ import brainpy as bp +import brainpy.math as bm import unittest @@ -28,3 +29,22 @@ def test2(self): self.assertTrue(not isinstance(bp.dyn.Expon(1), bp.mixin.ParamDescInit[T])) self.assertTrue(isinstance(bp.dyn.Expon.desc(1), bp.mixin.ParamDescInit[T])) + +class TestDelayRegister(unittest.TestCase): + def test11(self): + lif = bp.dyn.Lif(10) + with self.assertWarns(UserWarning): + lif.register_delay('pre.spike', 10, lif.spike) + + with self.assertWarns(UserWarning): + lif.get_delay_data('pre.spike', 10) + + def test2(self): + bp.share.save(i=0) + lif = bp.dyn.Lif(10) + lif.register_delay_at('a', 10.) + data = lif.get_delay_at('a') + self.assertTrue(bm.allclose(data, bm.zeros(10))) + + +