Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compatibility updates #442

Merged
merged 7 commits into from
Aug 13, 2023
Merged
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.3.post3"
__version__ = "2.4.3.post4"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def set_dt(self, dt: Union[int, float]):
self._arguments['dt'] = dt

def load(self, key, value: Any = None):
"""Get the shared data by the ``key``.
"""Load the shared data by the ``key``.

Args:
key (str): the key to indicate the data.
Expand Down
42 changes: 41 additions & 1 deletion brainpy/_src/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from brainpy._src.dynsys import DynamicalSystem
from brainpy._src.initialize import variable_
from brainpy._src.math.delayvars import ROTATE_UPDATE, CONCAT_UPDATE
from brainpy._src.mixin import ParamDesc
from brainpy._src.mixin import ParamDesc, ReturnInfo
from brainpy.check import jit_error


Expand All @@ -28,6 +28,9 @@
]


delay_identifier = '_*_delay_*_'


class Delay(DynamicalSystem, ParamDesc):
"""Base class for delay variables.

Expand Down Expand Up @@ -474,3 +477,40 @@ def update(self):
return self.delay.at(self.name, *self.indices)


def init_delay_by_return(info: Union[bm.Variable, ReturnInfo]) -> Delay:
if isinstance(info, bm.Variable):
return VarDelay(info)

elif isinstance(info, ReturnInfo):
# batch size
if isinstance(info.batch_or_mode, int):
shape = (info.batch_or_mode,) + tuple(info.size)
batch_axis = 0
elif isinstance(info.batch_or_mode, bm.NonBatchingMode):
shape = tuple(info.size)
batch_axis = None
elif isinstance(info.batch_or_mode, bm.BatchingMode):
shape = (info.batch_or_mode.batch_size,) + tuple(info.size)
batch_axis = 0
else:
shape = tuple(info.size)
batch_axis = None

# init
if isinstance(info.data, Callable):
init = info.data(shape)
elif isinstance(info.data, (bm.Array, jax.Array)):
init = info.data
else:
raise TypeError
assert init.shape == shape

# axis names
if info.axis_names is not None:
assert init.ndim == len(info.axis_names)

# variable
target = bm.Variable(init, batch_axis=batch_axis, axis_names=info.axis_names)
return DataDelay(target, data_init=info.data)
else:
raise TypeError
8 changes: 4 additions & 4 deletions brainpy/_src/dyn/ions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,15 +82,15 @@ def check_hierarchy(self, roots, leaf):
raise TypeError(f'Type does not match. {leaf} requires a master with type '
f'of {leaf.master_type}, but the master type now is {roots}.')

def add_elem(self, **elements):
def add_elem(self, *elems, **elements):
"""Add new elements.

Args:
elements: children objects.
"""
self.check_hierarchies(self._ion_classes, **elements)
self.children.update(self.format_elements(IonChaDyn, **elements))
for key, elem in elements.items():
self.check_hierarchies(self._ion_classes, *elems, **elements)
self.children.update(self.format_elements(IonChaDyn, *elems, **elements))
for elem in tuple(elems) + tuple(elements.values()):
for ion_root in elem.master_type.__args__:
ion = self._get_imp(ion_root)
ion.add_external_current(elem.name, self._get_ion_fun(ion, elem))
Expand Down
Loading