diff --git a/.gitignore b/.gitignore
index 0142f5132..1a494a3d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,7 +3,7 @@ publishment.md
.vscode
io_test_tmp*
-brainpy/base/tests/io_test_tmp*
+brainpy/math/brainpy_object/tests/io_test_tmp*
development
@@ -217,3 +217,7 @@ cython_debug/
/docs/apis/simulation/generated/
!/brainpy/dyn/tests/data/
/examples/dynamics_simulation/data/
+/examples/training_snn_models/logs/T100_b64_lr0.001/
+/examples/training_snn_models/logs/
+/examples/training_snn_models/data/
+/docs/tutorial_advanced/data/
diff --git a/brainpy/__init__.py b/brainpy/__init__.py
index da1a218e1..bac4404d7 100644
--- a/brainpy/__init__.py
+++ b/brainpy/__init__.py
@@ -6,22 +6,8 @@
# fundamental modules
from . import errors, check, tools
-# "base" module
-from . import base
-from .base import (
- # base class
- Base,
- BrainPyObject,
-
- # collector
- Collector,
- ArrayCollector,
- TensorCollector,
-)
-
# math foundation
from . import math
-from . import modes
# toolboxes
from . import (
@@ -69,7 +55,7 @@
synouts, # synaptic output
synplast, # synaptic plasticity
- # base classes
+ # brainpy_object classes
DynamicalSystem,
Container,
Sequential,
@@ -113,9 +99,7 @@
# running
from . import running
-from .running import (
- Runner
-)
+from .running import Runner
# "visualization" module, will be removed soon
from .visualization import visualize
@@ -124,3 +108,15 @@
conn = connect
init = initialize
optim = optimizers
+
+from . import experimental
+
+
+# deprecated
+from . import base
+# use ``brainpy.math.*`` instead
+from brainpy.math.object_transform.base_object import (Base, BrainPyObject,)
+# use ``brainpy.math.*`` instead
+from brainpy.math.object_transform.collector import (Collector, ArrayCollector, TensorCollector,)
+# use ``brainpy.math.*`` instead
+from . import modes
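
The hunks above turn the old top-level names into a thin compatibility layer: they stay importable from `brainpy`, but now resolve to the implementations under `brainpy.math.object_transform`. A minimal sketch of what user code sees after this change (the `Base is BrainPyObject` identity is an assumption carried over from the old alias):

```python
import brainpy as bp
import brainpy.math as bm

# Old spellings keep working through the deprecation shims...
assert bp.BrainPyObject is bm.BrainPyObject
assert bp.Base is bp.BrainPyObject  # Base was an alias of BrainPyObject

# ...but new code is expected to reach the same objects via brainpy.math.
obj = bm.BrainPyObject(name='demo_obj')
print(obj.name)  # -> 'demo_obj'
```
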
diff --git a/brainpy/algorithms/offline.py b/brainpy/algorithms/offline.py
index 915b4257a..d622f7036 100644
--- a/brainpy/algorithms/offline.py
+++ b/brainpy/algorithms/offline.py
@@ -6,14 +6,13 @@
from jax.lax import while_loop
import brainpy.math as bm
-from brainpy.base import BrainPyObject
from brainpy.types import ArrayType
from .utils import (Sigmoid,
Regularization, L1Regularization, L1L2Regularization, L2Regularization,
polynomial_features, normalize)
__all__ = [
- # base class for offline training algorithm
+ # brainpy_object class for offline training algorithm
'OfflineAlgorithm',
# training methods
@@ -33,7 +32,7 @@
name2func = dict()
-class OfflineAlgorithm(BrainPyObject):
+class OfflineAlgorithm(bm.BrainPyObject):
"""Base class for offline training algorithm."""
def __init__(self, name=None):
diff --git a/brainpy/algorithms/online.py b/brainpy/algorithms/online.py
index 946735b78..8c16b1b3e 100644
--- a/brainpy/algorithms/online.py
+++ b/brainpy/algorithms/online.py
@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-
import brainpy.math as bm
-from brainpy.base import BrainPyObject
from jax import vmap
import jax.numpy as jnp
__all__ = [
- # base class
+ # brainpy_object class
'OnlineAlgorithm',
# online learning algorithms
@@ -21,7 +20,7 @@
name2func = dict()
-class OnlineAlgorithm(BrainPyObject):
+class OnlineAlgorithm(bm.BrainPyObject):
"""Base class for online training algorithm."""
def __init__(self, name=None):
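
Since `OnlineAlgorithm` (and its offline counterpart above) now derives from `bm.BrainPyObject`, user-defined rules inherit naming and variable collection from the math package rather than from `brainpy.base`. A hedged sketch of the subclassing pattern (the constructor follows the `name=None` convention visible in the hunk; the `lr` field is illustrative):

```python
import brainpy.math as bm
from brainpy.algorithms.online import OnlineAlgorithm

class MyRule(OnlineAlgorithm):
    """A toy rule; the actual training interface is defined by OnlineAlgorithm."""
    def __init__(self, lr=0.1, name=None):
        super().__init__(name=name)
        self.lr = lr

rule = MyRule()
assert isinstance(rule, bm.BrainPyObject)  # participates in the object system
print(rule.name)  # auto-generated unique name, e.g. 'MyRule0'
```
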
diff --git a/brainpy/analysis/highdim/slow_points.py b/brainpy/analysis/highdim/slow_points.py
index 39dcd4f97..4c65831ce 100644
--- a/brainpy/analysis/highdim/slow_points.py
+++ b/brainpy/analysis/highdim/slow_points.py
@@ -13,7 +13,6 @@
import brainpy.math as bm
from brainpy import optimizers as optim, losses
from brainpy.analysis import utils, base, constants
-from brainpy.base import ArrayCollector
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
@@ -133,11 +132,11 @@ def __init__(
# update function
if target_vars is None:
- self.target_vars = ArrayCollector()
+ self.target_vars = bm.ArrayCollector()
else:
if not isinstance(target_vars, dict):
raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
- self.target_vars = ArrayCollector(target_vars)
+ self.target_vars = bm.ArrayCollector(target_vars)
excluded_vars = () if excluded_vars is None else excluded_vars
if isinstance(excluded_vars, dict):
excluded_vars = tuple(excluded_vars.values())
diff --git a/brainpy/analysis/lowdim/lowdim_analyzer.py b/brainpy/analysis/lowdim/lowdim_analyzer.py
index f57c21261..ae42f0499 100644
--- a/brainpy/analysis/lowdim/lowdim_analyzer.py
+++ b/brainpy/analysis/lowdim/lowdim_analyzer.py
@@ -12,7 +12,6 @@
from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.analysis.base import DSAnalyzer
-from brainpy.base.collector import Collector
pyplot = None
@@ -91,7 +90,7 @@ def __init__(
if not isinstance(target_vars, dict):
raise errors.AnalyzerError('"target_vars" must be a dict, with the format of '
'{"var1": (var1_min, var1_max)}.')
- self.target_vars = Collector(target_vars)
+ self.target_vars = bm.Collector(target_vars)
self.target_var_names = list(self.target_vars.keys()) # list of target vars
for key in self.target_vars.keys():
if key not in self.model.variables:
@@ -110,7 +109,7 @@ def __init__(
for key in fixed_vars.keys():
if key not in self.model.variables:
raise ValueError(f'{key} is not a dynamical variable in {self.model}.')
- self.fixed_vars = Collector(fixed_vars)
+ self.fixed_vars = bm.Collector(fixed_vars)
# check duplicate
for key in self.fixed_vars.keys():
@@ -125,7 +124,7 @@ def __init__(
if not isinstance(pars_update, dict):
raise errors.AnalyzerError('"pars_update" must be a dict with the format '
'of {"par1": val1, "par2": val2}.')
- pars_update = Collector(pars_update)
+ pars_update = bm.Collector(pars_update)
for key in pars_update.keys():
if key not in self.model.parameters:
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
@@ -144,7 +143,7 @@ def __init__(
raise errors.AnalyzerError(
f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')
- self.target_pars = Collector(target_pars)
+ self.target_pars = bm.Collector(target_pars)
self.target_par_names = list(self.target_pars.keys()) # list of target_pars
# check duplicate
diff --git a/brainpy/base/__init__.py b/brainpy/base/__init__.py
index 38399e8ed..64a003bf8 100644
--- a/brainpy/base/__init__.py
+++ b/brainpy/base/__init__.py
@@ -1,28 +1,14 @@
# -*- coding: utf-8 -*-
"""
-The ``base`` module for whole BrainPy ecosystem.
-
-- This module provides the most fundamental class ``BrainPyObject``,
- and its associated helper class ``Collector`` and ``ArrayCollector``.
-- For each instance of "BrainPyObject" class, users can retrieve all
- the variables (or trainable variables), integrators, and nodes.
-- This module also provides a ``FunAsObject`` class to wrap user-defined
- functions. In each function, maybe several nodes are used, and
- users can initialize a ``FunAsObject`` by providing the nodes used
- in the function. Unfortunately, ``FunAsObject`` class does not have
- the ability to gather nodes automatically.
-- This module provides ``io`` helper functions to help users save/load
- model states, or share user's customized model with others.
-- This module provides ``naming`` tools to guarantee the unique nameing
- for each BrainPyObject object.
-
-Details please see the following.
+This module is deprecated since version 2.3.1.
+Please use ``brainpy.math.*`` instead.
"""
-from brainpy.base.base import *
-from brainpy.base.collector import *
-from brainpy.base.function import *
-from brainpy.base.io import *
-from brainpy.base.naming import *
+
+from .base import *
+from .collector import *
+from .function import *
+from .io import *
+from .naming import *
diff --git a/brainpy/base/base.py b/brainpy/base/base.py
index 0d244b670..437f002b5 100644
--- a/brainpy/base/base.py
+++ b/brainpy/base/base.py
@@ -1,392 +1,14 @@
# -*- coding: utf-8 -*-
-import logging
-import os.path
-import warnings
-from collections import namedtuple
-from typing import Dict, Any, Tuple
-
-from brainpy import errors
-from brainpy.base import io, naming
-from brainpy.base.collector import Collector, ArrayCollector
-
-math = None
-StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
+from brainpy.math.object_transform import base_object
__all__ = [
- 'BrainPyObject',
- 'Base',
+ 'BrainPyObject', 'Base',
]
-logger = logging.getLogger('brainpy.base')
-
-
-class BrainPyObject(object):
- """The BrainPyObject class for whole BrainPy ecosystem.
-
- The subclass of BrainPyObject includes:
-
- - ``DynamicalSystem`` in *brainpy.dyn.base.py*
- - ``Integrator`` in *brainpy.integrators.base.py*
- - ``FunAsObject`` in *brainpy.base.function.py*
- - ``Optimizer`` in *brainpy.optimizers.py*
- - ``Scheduler`` in *brainpy.optimizers.py*
- - and others.
- """
-
- _excluded_vars = ()
-
- def __init__(self, name=None):
- # check whether the object has a unique name.
- self._name = None
- self._name = self.unique_name(name=name)
- naming.check_name_uniqueness(name=self._name, obj=self)
-
- # Used to wrap the implicit variables
- # which cannot be accessed by self.xxx
- self.implicit_vars = ArrayCollector()
-
- # Used to wrap the implicit children nodes
- # which cannot be accessed by self.xxx
- self.implicit_nodes = Collector()
-
- @property
- def name(self):
- """Name of the model."""
- return self._name
-
- @name.setter
- def name(self, name: str = None):
- self._name = self.unique_name(name=name)
- naming.check_name_uniqueness(name=self._name, obj=self)
-
- def register_implicit_vars(self, *variables, **named_variables):
- global math
- if math is None: from brainpy import math
-
- for variable in variables:
- if isinstance(variable, math.Variable):
- self.implicit_vars[f'var{id(variable)}'] = variable
- elif isinstance(variable, (tuple, list)):
- for v in variable:
- if not isinstance(v, math.Variable):
- raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(v)}')
- self.implicit_vars[f'var{id(v)}'] = v
- elif isinstance(variable, dict):
- for k, v in variable.items():
- if not isinstance(v, math.Variable):
- raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(v)}')
- self.implicit_vars[k] = v
- else:
- raise ValueError(f'Unknown type: {type(variable)}')
- for key, variable in named_variables.items():
- if not isinstance(variable, math.Variable):
- raise ValueError(f'Must be instance of {math.Variable.__name__}, but we got {type(variable)}')
- self.implicit_vars[key] = variable
-
- def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
- if node_cls is None:
- node_cls = BrainPyObject
- for node in nodes:
- if isinstance(node, node_cls):
- self.implicit_nodes[node.name] = node
- elif isinstance(node, (tuple, list)):
- for n in node:
- if not isinstance(n, node_cls):
- raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(n)}')
- self.implicit_nodes[n.name] = n
- elif isinstance(node, dict):
- for k, n in node.items():
- if not isinstance(n, node_cls):
- raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(n)}')
- self.implicit_nodes[k] = n
- else:
- raise ValueError(f'Unknown type: {type(node)}')
- for key, node in named_nodes.items():
- if not isinstance(node, node_cls):
- raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(node)}')
- self.implicit_nodes[key] = node
-
- def vars(self,
- method: str = 'absolute',
- level: int = -1,
- include_self: bool = True,
- exclude_types: Tuple[type, ...] = None):
- """Collect all variables in this node and the children nodes.
-
- Parameters
- ----------
- method : str
- The method to access the variables.
- level: int
- The hierarchy level to find variables.
- include_self: bool
- Whether include the variables in the self.
- exclude_types: tuple of type
- The type to exclude.
-
- Returns
- -------
- gather : ArrayCollector
- The collection contained (the path, the variable).
- """
- global math
- if math is None: from brainpy import math
-
- if exclude_types is None:
- exclude_types = (math.VariableView, )
- nodes = self.nodes(method=method, level=level, include_self=include_self)
- gather = ArrayCollector()
- for node_path, node in nodes.items():
- for k in dir(node):
- v = getattr(node, k)
- include = False
- if isinstance(v, math.Variable):
- include = True
- if len(exclude_types) and isinstance(v, exclude_types):
- include = False
- if include:
- if k not in node._excluded_vars:
- gather[f'{node_path}.{k}' if node_path else k] = v
- gather.update({f'{node_path}.{k}': v for k, v in node.implicit_vars.items()})
- return gather
-
- def train_vars(self, method='absolute', level=-1, include_self=True):
- """The shortcut for retrieving all trainable variables.
-
- Parameters
- ----------
- method : str
- The method to access the variables. Support 'absolute' and 'relative'.
- level: int
- The hierarchy level to find TrainVar instances.
- include_self: bool
- Whether include the TrainVar instances in the self.
-
- Returns
- -------
- gather : ArrayCollector
- The collection contained (the path, the trainable variable).
- """
- global math
- if math is None: from brainpy import math
- return self.vars(method=method, level=level, include_self=include_self).subset(math.TrainVar)
-
- def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _paths=None):
- if _paths is None:
- _paths = set()
- gather = Collector()
- if include_self:
- if method == 'absolute':
- gather[self.name] = self
- elif method == 'relative':
- gather[''] = self
- else:
- raise ValueError(f'No support for the method of "{method}".')
- if (level > -1) and (_lid >= level):
- return gather
- if method == 'absolute':
- nodes = []
- for k, v in self.__dict__.items():
- if isinstance(v, BrainPyObject):
- path = (id(self), id(v))
- if path not in _paths:
- _paths.add(path)
- gather[v.name] = v
- nodes.append(v)
- for node in self.implicit_nodes.values():
- path = (id(self), id(node))
- if path not in _paths:
- _paths.add(path)
- gather[node.name] = node
- nodes.append(node)
- for v in nodes:
- gather.update(v._find_nodes(method=method,
- level=level,
- _lid=_lid + 1,
- _paths=_paths,
- include_self=include_self))
-
- elif method == 'relative':
- nodes = []
- for k, v in self.__dict__.items():
- if isinstance(v, BrainPyObject):
- path = (id(self), id(v))
- if path not in _paths:
- _paths.add(path)
- gather[k] = v
- nodes.append((k, v))
- for key, node in self.implicit_nodes.items():
- path = (id(self), id(node))
- if path not in _paths:
- _paths.add(path)
- gather[key] = node
- nodes.append((key, node))
- for k1, v1 in nodes:
- for k2, v2 in v1._find_nodes(method=method,
- _paths=_paths,
- _lid=_lid + 1,
- level=level,
- include_self=include_self).items():
- if k2: gather[f'{k1}.{k2}'] = v2
-
- else:
- raise ValueError(f'No support for the method of "{method}".')
- return gather
-
- def nodes(self, method='absolute', level=-1, include_self=True):
- """Collect all children nodes.
-
- Parameters
- ----------
- method : str
- The method to access the nodes.
- level: int
- The hierarchy level to find nodes.
- include_self: bool
- Whether include the self.
-
- Returns
- -------
- gather : Collector
- The collection contained (the path, the node).
- """
- return self._find_nodes(method=method, level=level, include_self=include_self)
-
- def unique_name(self, name=None, type_=None):
- """Get the unique name for this object.
-
- Parameters
- ----------
- name : str, optional
- The expected name. If None, the default unique name will be returned.
- Otherwise, the provided name will be checked to guarantee its uniqueness.
- type_ : str, optional
- The name of this class, used for object naming.
-
- Returns
- -------
- name : str
- The unique name for this object.
- """
- if name is None:
- if type_ is None:
- return naming.get_unique_name(type_=self.__class__.__name__)
- else:
- return naming.get_unique_name(type_=type_)
- else:
- naming.check_name_uniqueness(name=name, obj=self)
- return name
-
- def load_states(self, filename, verbose=False):
- """Load the model states.
-
- Parameters
- ----------
- filename : str
- The filename which stores the model states.
- verbose: bool
- Whether report the load progress.
- """
- if not os.path.exists(filename):
- raise errors.BrainPyError(f'Cannot find the file path: {filename}')
- elif filename.endswith('.hdf5') or filename.endswith('.h5'):
- io.load_by_h5(filename, target=self, verbose=verbose)
- elif filename.endswith('.pkl'):
- io.load_by_pkl(filename, target=self, verbose=verbose)
- elif filename.endswith('.npz'):
- io.load_by_npz(filename, target=self, verbose=verbose)
- elif filename.endswith('.mat'):
- io.load_by_mat(filename, target=self, verbose=verbose)
- else:
- raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
-
- def save_states(self, filename, variables=None, **setting):
- """Save the model states.
-
- Parameters
- ----------
- filename : str
- The file name which to store the model states.
- variables: optional, dict, ArrayCollector
- The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
- """
- if variables is None:
- variables = self.vars(method='absolute', level=-1)
-
- if filename.endswith('.hdf5') or filename.endswith('.h5'):
- io.save_as_h5(filename, variables=variables)
- elif filename.endswith('.pkl') or filename.endswith('.pickle'):
- io.save_as_pkl(filename, variables=variables)
- elif filename.endswith('.npz'):
- io.save_as_npz(filename, variables=variables, **setting)
- elif filename.endswith('.mat'):
- io.save_as_mat(filename, variables=variables)
- else:
- raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
-
- def state_dict(self):
- """Returns a dictionary containing a whole state of the module.
-
- Returns
- -------
- out: dict
- A dictionary containing a whole state of the module.
- """
- return self.vars().unique().dict()
-
- def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True):
- """Copy parameters and buffers from :attr:`state_dict` into
- this module and its descendants.
-
- Parameters
- ----------
- state_dict: dict
- A dict containing parameters and persistent buffers.
- warn: bool
- Warnings when there are missing keys or unexpected keys in the external ``state_dict``.
-
- Returns
- -------
- out: StateLoadResult
- ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
-
- * **missing_keys** is a list of str containing the missing keys
- * **unexpected_keys** is a list of str containing the unexpected keys
- """
- variables = self.vars().unique()
- keys1 = set(state_dict.keys())
- keys2 = set(variables.keys())
- unexpected_keys = list(keys1 - keys2)
- missing_keys = list(keys2 - keys1)
- for key in keys2.intersection(keys1):
- variables[key].value = state_dict[key]
- if warn:
- if len(unexpected_keys):
- warnings.warn(f'Unexpected keys in state_dict: {unexpected_keys}', UserWarning)
- if len(missing_keys):
- warnings.warn(f'Missing keys in state_dict: {missing_keys}', UserWarning)
- return StateLoadResult(missing_keys, unexpected_keys)
-
- # def to(self, devices):
- # global math
- # if math is None: from brainpy import math
- #
- # def cpu(self):
- # global math
- # if math is None: from brainpy import math
- #
- # all_vars = self.vars().unique()
- # for data in all_vars.values():
- # data[:] = math.asarray(data.value)
- #
- # def cuda(self):
- # global math
- # if math is None: from brainpy import math
- #
- # def tpu(self):
- # global math
- # if math is None: from brainpy import math
+# use `brainpy.math.BrainPyObject` instead
+BrainPyObject = base_object.BrainPyObject
+# use `brainpy.math.Base` instead
+Base = base_object.Base
-Base = BrainPyObject
diff --git a/brainpy/base/collector.py b/brainpy/base/collector.py
index 0e9533b95..edd2b4435 100644
--- a/brainpy/base/collector.py
+++ b/brainpy/base/collector.py
@@ -1,12 +1,6 @@
# -*- coding: utf-8 -*-
-
-from typing import Dict, Sequence, Union
-
-from jax.tree_util import register_pytree_node
-from jax.util import safe_zip
-
-math = None
+from brainpy.math.object_transform import collector
__all__ = [
'Collector',
@@ -14,212 +8,7 @@
'TensorCollector',
]
+Collector = collector.Collector
+ArrayCollector = collector.ArrayCollector
+TensorCollector = collector.TensorCollector
-class Collector(dict):
- """A Collector is a dictionary (name, var) with some additional methods to make manipulation
- of collections of variables easy. A Collector is ordered by insertion order. It is the object
- returned by BrainPyObject.vars() and used as input in many Collector instance: optimizers, jit, etc..."""
-
- def __setitem__(self, key, value):
- """Overload bracket assignment to catch potential conflicts during assignment."""
- if key in self:
- if id(self[key]) != id(value):
- raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
- dict.__setitem__(self, key, value)
-
- def replace(self, key, new_value):
- """Replace the original key with the new value."""
- self.pop(key)
- self[key] = new_value
-
- def update(self, other, **kwargs):
- assert isinstance(other, (dict, list, tuple))
- if isinstance(other, dict):
- for key, value in other.items():
- self[key] = value
- elif isinstance(other, (tuple, list)):
- num = len(self)
- for i, value in enumerate(other):
- self[f'_var{i + num}'] = value
- else:
- raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}')
- for key, value in kwargs.items():
- self[key] = value
- return self
-
- def __add__(self, other):
- """Merging two dicts.
-
- Parameters
- ----------
- other: dict
- The other dict instance.
-
- Returns
- -------
- gather: Collector
- The new collector.
- """
- gather = type(self)(self)
- gather.update(other)
- return gather
-
- def __sub__(self, other: Union[Dict, Sequence]):
- """Remove other item in the collector.
-
- Parameters
- ----------
- other: dict, sequence
- The items to remove.
-
- Returns
- -------
- gather: Collector
- The new collector.
- """
- if not isinstance(other, (dict, tuple, list)):
- raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.')
- gather = type(self)(self)
- if isinstance(other, dict):
- for key, val in other.items():
- if key in gather:
- if id(val) != id(gather[key]):
- raise ValueError(f'Cannot remove {key}, because we got two different values: '
- f'{val} != {gather[key]}')
- gather.pop(key)
- else:
- raise ValueError(f'Cannot remove {key}, because we do not find it '
- f'in {self.keys()}.')
- elif isinstance(other, (list, tuple)):
- id_to_keys = {}
- for k, v in self.items():
- id_ = id(v)
- if id_ not in id_to_keys:
- id_to_keys[id_] = []
- id_to_keys[id_].append(k)
-
- keys_to_remove = []
- for key in other:
- if isinstance(key, str):
- keys_to_remove.append(key)
- else:
- keys_to_remove.extend(id_to_keys[id(key)])
-
- for key in set(keys_to_remove):
- if key in gather:
- gather.pop(key)
- else:
- raise ValueError(f'Cannot remove {key}, because we do not find it '
- f'in {self.keys()}.')
- else:
- raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}')
- return gather
-
- def subset(self, var_type):
- """Get the subset of the (key, value) pair.
-
- ``subset()`` can be used to get a subset of some class:
-
- >>> import brainpy as bp
- >>>
- >>> some_collector = Collector()
- >>>
- >>> # get all trainable variables
- >>> some_collector.subset(bp.math.TrainVar)
- >>>
- >>> # get all Variable
- >>> some_collector.subset(bp.math.Variable)
-
- or, it can be used to get a subset of integrators:
-
- >>> # get all ODE integrators
- >>> some_collector.subset(bp.ode.ODEIntegrator)
-
- Parameters
- ----------
- var_type : type
- The type/class to match.
- """
- gather = type(self)()
- for key, value in self.items():
- if isinstance(value, var_type):
- gather[key] = value
- return gather
-
- def unique(self):
- """Get a new type of collector with unique values.
-
- If one value is assigned to two or more keys,
- then only one pair of (key, value) will be returned.
- """
- gather = type(self)()
- seen = set()
- for k, v in self.items():
- if id(v) not in seen:
- seen.add(id(v))
- gather[k] = v
- return gather
-
-
-class ArrayCollector(Collector):
- """A ArrayCollector is a dictionary (name, var)
- with some additional methods to make manipulation
- of collections of variables easy. A Collection
- is ordered by insertion order. It is the object
- returned by DynamicalSystem.vars() and used as input
- in many DynamicalSystem instance: optimizers, Jit, etc..."""
-
- def __setitem__(self, key, value):
- """Overload bracket assignment to catch potential conflicts during assignment."""
- global math
- if math is None: from brainpy import math
-
- assert isinstance(value, math.ndarray)
- if key in self:
- if id(self[key]) != id(value):
- raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
- dict.__setitem__(self, key, value)
-
- def assign(self, inputs):
- """Assign data to all values.
-
- Parameters
- ----------
- inputs : dict
- The data for each value in this collector.
- """
- if len(self) != len(inputs):
- raise ValueError(f'The target has {len(inputs)} data, while we got '
- f'an input value with the length of {len(inputs)}.')
- for key, v in self.items():
- v.value = inputs[key]
-
- def dict(self):
- """Get a dict with the key and the value data.
- """
- gather = dict()
- for k, v in self.items():
- gather[k] = v.value
- return gather
-
- def data(self):
- """Get all data in each value."""
- return [x.value for x in self.values()]
-
- @classmethod
- def from_other(cls, other: Union[Sequence, Dict]):
- if isinstance(other, (tuple, list)):
- return cls({id(o): o for o in other})
- elif isinstance(other, dict):
- return cls(other)
- else:
- raise TypeError
-
-
-TensorCollector = ArrayCollector
-
-register_pytree_node(
- ArrayCollector,
- lambda x: (x.values(), x.keys()),
- lambda keys, values: ArrayCollector(safe_zip(keys, values))
-)
diff --git a/brainpy/base/function.py b/brainpy/base/function.py
index 0eedafc76..a7e31018b 100644
--- a/brainpy/base/function.py
+++ b/brainpy/base/function.py
@@ -1,54 +1,12 @@
# -*- coding: utf-8 -*-
-from typing import Callable, Sequence, Dict, Union, TypeVar
-
-from brainpy.base.base import BrainPyObject
-
-
-Variable = TypeVar('Variable')
-
+from brainpy.math.object_transform import base_object
__all__ = [
'FunAsObject',
+ 'Function',
]
+FunAsObject = base_object.FunAsObject
+Function = base_object.FunAsObject
-class FunAsObject(BrainPyObject):
- """Transform a Python function as a :py:class:`~.BrainPyObject`.
-
- Parameters
- ----------
- f : callable
- The function to wrap.
- child_objs : optional, BrainPyObject, sequence of BrainPyObject, dict
- The nodes in the defined function ``f``.
- dyn_vars : optional, Variable, sequence of Variable, dict
- The dynamically changed variables.
- name : optional, str
- The function name.
- """
-
- def __init__(self,
- f: Callable,
- child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[dict, BrainPyObject]] = None,
- dyn_vars: Union[Variable, Sequence[Variable], Dict[dict, Variable]] = None,
- name: str = None):
- super(FunAsObject, self).__init__(name=name)
- self._f = f
- if child_objs is not None:
- self.register_implicit_nodes(child_objs)
- if dyn_vars is not None:
- self.register_implicit_vars(dyn_vars)
-
- def __call__(self, *args, **kwargs):
- return self._f(*args, **kwargs)
-
- def __repr__(self) -> str:
- from brainpy.tools import repr_context
- name = self.__class__.__name__
- indent = " " * (len(name) + 1)
- indent2 = indent + " " * len('nodes=')
- nodes = [repr_context(str(n), indent2) for n in self.implicit_nodes.values()]
- node_string = ", \n".join(nodes)
- return (f'{name}(nodes=[{node_string}],\n' +
- " " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')
diff --git a/brainpy/base/io.py b/brainpy/base/io.py
index cbddcf1d4..060eeaec2 100644
--- a/brainpy/base/io.py
+++ b/brainpy/base/io.py
@@ -1,392 +1,24 @@
# -*- coding: utf-8 -*-
-from typing import Dict, Type, Union, Tuple, List
-import logging
-import pickle
-
-import numpy as np
-
-from brainpy import errors
-from brainpy.base.collector import ArrayCollector
-
-logger = logging.getLogger('brainpy.base.io')
+from brainpy import checkpoints
__all__ = [
'SUPPORTED_FORMATS',
- 'save_as_h5',
- 'save_as_npz',
- 'save_as_pkl',
- 'save_as_mat',
- 'load_by_h5',
- 'load_by_npz',
- 'load_by_pkl',
- 'load_by_mat',
+ 'save_as_h5', 'load_by_h5',
+ 'save_as_npz', 'load_by_npz',
+ 'save_as_pkl', 'load_by_pkl',
+ 'save_as_mat', 'load_by_mat',
]
-SUPPORTED_FORMATS = ['.h5', '.hdf5', '.npz', '.pkl', '.mat']
-
-
-def check_dict_data(
- a_dict: Dict,
- key_type: Union[Type, Tuple[Type, ...]] = None,
- val_type: Union[Type, Tuple[Type, ...]] = None,
- name: str = None
-):
- """Check the dict data."""
- name = '' if (name is None) else f'"{name}"'
- if not isinstance(a_dict, dict):
- raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}')
- if key_type is not None:
- for key, value in a_dict.items():
- if not isinstance(key, key_type):
- raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
- f'while we got ({type(key)}, {type(value)})')
- if val_type is not None:
- for key, value in a_dict.items():
- if not isinstance(value, val_type):
- raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
- f'while we got ({type(key)}, {type(value)})')
-
-
-def _check_module(module, module_name, ext):
- """Check whether the required module is installed."""
- if module is None:
- raise errors.PackageMissingError(
- '"{package}" must be installed when you want to save/load data with {ext} '
- 'format. \nPlease install {package} through "pip install {package}" or '
- '"conda install {package}".'.format(package=module_name, ext=ext)
- )
-
-
-def _check_missing(variables, filename):
- if len(variables):
- logger.warning(f'There are variable states missed in {filename}. '
- f'The missed variables are: {list(variables.keys())}.')
-
-
-def _check_target(target):
- from .base import BrainPyObject
- if not isinstance(target, BrainPyObject):
- raise TypeError(f'"target" must be instance of "{BrainPyObject.__name__}", but we got {type(target)}')
-
-
-not_found_msg = ('"{key}" is stored in {filename}. But we does '
- 'not find it is defined as variable in {target}.')
-id_mismatch_msg = ('{key1} and {key2} is the same data in {filename}. '
- 'But we found they are different in {target}.')
-
-DUPLICATE_KEY = 'duplicate_keys'
-DUPLICATE_TARGET = 'duplicate_targets'
-
-
-def _load(
- target,
- verbose: bool,
- filename: str,
- load_vars: dict,
- duplicates: Tuple[List[str], List[str]],
- remove_first_axis: bool = False
-):
- from brainpy import math as bm
-
- # get variables
- _check_target(target)
- variables = target.vars(method='absolute', level=-1)
- var_names_in_obj = list(variables.keys())
-
- # read data from file
- for key in load_vars.keys():
- if verbose:
- print(f'Loading {key} ...')
- if key not in variables:
- raise KeyError(not_found_msg.format(key=key, target=target.name, filename=filename))
- if remove_first_axis:
- value = load_vars[key][0]
- else:
- value = load_vars[key]
- variables[key].value = bm.asarray(value)
- var_names_in_obj.remove(key)
-
- # check duplicate names
- duplicate_keys = duplicates[0]
- duplicate_targets = duplicates[1]
- for key1, key2 in zip(duplicate_keys, duplicate_targets):
- if key1 not in var_names_in_obj:
- raise KeyError(not_found_msg.format(key=key1, target=target.name, filename=filename))
- if id(variables[key1]) != id(variables[key2]):
- raise ValueError(id_mismatch_msg.format(key1=key1, key2=target, filename=filename, target=target.name))
- var_names_in_obj.remove(key1)
-
- # check missing names
- if len(var_names_in_obj):
- logger.warning(f'There are variable states missed in {filename}. '
- f'The missed variables are: {var_names_in_obj}.')
-
-
-def _unique_and_duplicate(collector: dict):
- gather = ArrayCollector()
- id2name = dict()
- duplicates = ([], [])
- for k, v in collector.items():
- id_ = id(v)
- if id_ not in id2name:
- gather[k] = v
- id2name[id_] = k
- else:
- k2 = id2name[id_]
- duplicates[0].append(k)
- duplicates[1].append(k2)
- duplicates = (duplicates[0], duplicates[1])
- return gather, duplicates
-
-
-def save_as_h5(filename: str, variables: dict):
- """Save variables into a HDF5 file.
-
- Parameters
- ----------
- filename: str
- The filename to save.
- variables: dict
- All variables to save.
- """
- if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
- raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
- f'postfix of ".hdf5" and ".h5". But we got {filename}')
-
- from brainpy import math as bm
- import h5py
-
- # check variables
- check_dict_data(variables, name='variables')
- variables, duplicates = _unique_and_duplicate(variables)
-
- # save
- f = h5py.File(filename, "w")
- for key, data in variables.items():
- f[key] = bm.as_numpy(data)
- if len(duplicates[0]):
- f.create_dataset(DUPLICATE_TARGET, data='+'.join(duplicates[1]))
- f.create_dataset(DUPLICATE_KEY, data='+'.join(duplicates[0]))
- f.close()
-
-
-def load_by_h5(filename: str, target, verbose: bool = False):
- """Load variables in a HDF5 file.
-
- Parameters
- ----------
- filename: str
- The filename to load variables.
- target: BrainPyObject
- The instance of :py:class:`~.brainpy.BrainPyObject`.
- verbose: bool
- Whether report the load progress.
- """
- if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
- raise ValueError(f'Cannot load variables from a HDF5 file. We only support file with '
- f'postfix of ".hdf5" and ".h5". But we got {filename}')
-
- # read data
- import h5py
- load_vars = dict()
- with h5py.File(filename, "r") as f:
- for key in f.keys():
- if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
- load_vars[key] = np.asarray(f[key])
- if DUPLICATE_KEY in f:
- duplicate_keys = np.asarray(f[DUPLICATE_KEY]).item().decode("utf-8").split('+')
- duplicate_targets = np.asarray(f[DUPLICATE_TARGET]).item().decode("utf-8").split('+')
- duplicates = (duplicate_keys, duplicate_targets)
- else:
- duplicates = ([], [])
-
- # assign values
- _load(target, verbose, filename, load_vars, duplicates)
-
-
-def save_as_npz(filename, variables, compressed=False):
- """Save variables into a numpy file.
-
- Parameters
- ----------
- filename: str
- The filename to store.
- variables: dict
- Variables to save.
- compressed: bool
- Whether we use the compressed mode.
- """
- if not filename.endswith('.npz'):
- raise ValueError(f'Cannot save variables as a .npz file. We only support file with '
- f'postfix of ".npz". But we got {filename}')
-
- from brainpy import math as bm
- check_dict_data(variables, name='variables')
- variables, duplicates = _unique_and_duplicate(variables)
-
- # save
- variables = {k: bm.as_numpy(v) for k, v in variables.items()}
- if len(duplicates[0]):
- variables[DUPLICATE_KEY] = np.asarray(duplicates[0])
- variables[DUPLICATE_TARGET] = np.asarray(duplicates[1])
- if compressed:
- np.savez_compressed(filename, **variables)
- else:
- np.savez(filename, **variables)
-
-
-def load_by_npz(filename, target, verbose=False):
- """Load variables from a numpy file.
-
- Parameters
- ----------
- filename: str
- The filename to load variables.
- target: BrainPyObject
- The instance of :py:class:`~.brainpy.BrainPyObject`.
- verbose: bool
- Whether report the load progress.
- """
- if not filename.endswith('.npz'):
- raise ValueError(f'Cannot load variables from a .npz file. We only support file with '
- f'postfix of ".npz". But we got {filename}')
-
- # load data
- load_vars = dict()
- all_data = np.load(filename)
- for key in all_data.files:
- if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
- load_vars[key] = all_data[key]
- if DUPLICATE_KEY in all_data:
- duplicate_keys = all_data[DUPLICATE_KEY].tolist()
- duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
- duplicates = (duplicate_keys, duplicate_targets)
- else:
- duplicates = ([], [])
-
- # assign values
- _load(target, verbose, filename, load_vars, duplicates)
-
-
-def save_as_pkl(filename, variables):
- """Save variables into a pickle file.
-
- Parameters
- ----------
- filename: str
- The filename to save.
- variables: dict
- All variables to save.
- """
- if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
- raise ValueError(f'Cannot save variables into a pickle file. We only support file with '
- f'postfix of ".pkl" and ".pickle". But we got {filename}')
-
- check_dict_data(variables, name='variables')
- variables, duplicates = _unique_and_duplicate(variables)
- import brainpy.math as bm
- targets = {k: bm.as_numpy(v) for k, v in variables.items()}
- if len(duplicates[0]) > 0:
- targets[DUPLICATE_KEY] = np.asarray(duplicates[0])
- targets[DUPLICATE_TARGET] = np.asarray(duplicates[1])
- with open(filename, 'wb') as f:
- pickle.dump(targets, f, protocol=pickle.HIGHEST_PROTOCOL)
-
-
-def load_by_pkl(filename, target, verbose=False):
- """Load variables from a pickle file.
-
- Parameters
- ----------
- filename: str
- The filename to load variables.
- target: BrainPyObject
- The instance of :py:class:`~.brainpy.BrainPyObject`.
- verbose: bool
- Whether report the load progress.
- """
- if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
- raise ValueError(f'Cannot load variables from a pickle file. We only support file with '
- f'postfix of ".pkl" and ".pickle". But we got {filename}')
-
- # load variables
- load_vars = dict()
- with open(filename, 'rb') as f:
- all_data = pickle.load(f)
- for key, data in all_data.items():
- if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
- load_vars[key] = data
- if DUPLICATE_KEY in all_data:
- duplicate_keys = all_data[DUPLICATE_KEY].tolist()
- duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
- duplicates = (duplicate_keys, duplicate_targets)
- else:
- duplicates = ([], [])
-
- # assign data
- _load(target, verbose, filename, load_vars, duplicates)
-
-
-def save_as_mat(filename, variables):
- """Save variables into a matlab file.
-
- Parameters
- ----------
- filename: str
- The filename to save.
- variables: dict
- All variables to save.
- """
- if not filename.endswith('.mat'):
- raise ValueError(f'Cannot save variables into a .mat file. We only support file with '
- f'postfix of ".mat". But we got {filename}')
-
- from brainpy import math as bm
- import scipy.io as sio
-
- check_dict_data(variables, name='variables')
- variables, duplicates = _unique_and_duplicate(variables)
- variables = {k: np.expand_dims(bm.as_numpy(v), axis=0) for k, v in variables.items()}
- if len(duplicates[0]):
- variables[DUPLICATE_KEY] = np.expand_dims(np.asarray(duplicates[0]), axis=0)
- variables[DUPLICATE_TARGET] = np.expand_dims(np.asarray(duplicates[1]), axis=0)
- sio.savemat(filename, variables)
-
-
-def load_by_mat(filename, target, verbose=False):
- """Load variables from a numpy file.
-
- Parameters
- ----------
- filename: str
- The filename to load variables.
- target: BrainPyObject
- The instance of :py:class:`~.brainpy.BrainPyObject`.
- verbose: bool
- Whether report the load progress.
- """
- if not filename.endswith('.mat'):
- raise ValueError(f'Cannot load variables from a .mat file. We only support file with '
- f'postfix of ".mat". But we got {filename}')
- import scipy.io as sio
+SUPPORTED_FORMATS = checkpoints.SUPPORTED_FORMATS
+save_as_h5 = checkpoints.save_as_h5
+load_by_h5 = checkpoints.load_by_h5
+save_as_npz = checkpoints.save_as_npz
+load_by_npz = checkpoints.load_by_npz
+save_as_pkl = checkpoints.save_as_pkl
+load_by_pkl = checkpoints.load_by_pkl
+save_as_mat = checkpoints.save_as_mat
+load_by_mat = checkpoints.load_by_mat
- # load data
- load_vars = dict()
- all_data = sio.loadmat(filename)
- for key, data in all_data.items():
- if key.startswith('__'):
- continue
- if key in [DUPLICATE_KEY, DUPLICATE_TARGET]:
- continue
- load_vars[key] = data[0]
- if DUPLICATE_KEY in all_data:
- duplicate_keys = [a.strip() for a in all_data[DUPLICATE_KEY].tolist()[0]]
- duplicate_targets = [a.strip() for a in all_data[DUPLICATE_TARGET].tolist()[0]]
- duplicates = (duplicate_keys, duplicate_targets)
- else:
- duplicates = ([], [])
- # assign values
- _load(target, verbose, filename, load_vars, duplicates)
diff --git a/brainpy/base/naming.py b/brainpy/base/naming.py
index 62e2542df..e854f92ed 100644
--- a/brainpy/base/naming.py
+++ b/brainpy/base/naming.py
@@ -1,10 +1,6 @@
# -*- coding: utf-8 -*-
-import logging
-
-from brainpy import errors
-
-logger = logging.getLogger('brainpy.base.naming')
+from brainpy.math.object_transform import base_object
__all__ = [
'check_name_uniqueness',
@@ -12,41 +8,9 @@
'clear_name_cache',
]
-_name2id = dict()
-_typed_names = {}
-
-
-def check_name_uniqueness(name, obj):
- """Check the uniqueness of the name for the object type."""
- if not name.isidentifier():
- raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier '
- f'according to Python language definition. '
- f'Please choose another name.')
- if name in _name2id:
- if _name2id[name] != id(obj):
- raise errors.UniqueNameError(
- f'In BrainPy, each object should have a unique name. '
- f'However, we detect that {obj} has a used name "{name}". \n'
- f'If you try to run multiple trials, you may need \n\n'
- f'>>> brainpy.base.clear_name_cache() \n\n'
- f'to clear all cached names. '
- )
- else:
- _name2id[name] = id(obj)
-
-def get_unique_name(type_):
- """Get the unique name for the given object type."""
- if type_ not in _typed_names:
- _typed_names[type_] = 0
- name = f'{type_}{_typed_names[type_]}'
- _typed_names[type_] += 1
- return name
+check_name_uniqueness = base_object.check_name_uniqueness
+get_unique_name = base_object.get_unique_name
+clear_name_cache = base_object.clear_name_cache
-def clear_name_cache(ignore_warn=False):
- """Clear the cached names."""
- _name2id.clear()
- _typed_names.clear()
- if not ignore_warn:
- logger.warning(f'All named models and their ids are cleared.')
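
The naming helpers are likewise re-exported from the relocated `base_object` module, so the escape hatch mentioned in the removed error message still exists through this shim. A small sketch using the import path defined in the hunk above (whether these helpers are also re-exported at the `brainpy.math` top level is an assumption not shown here):

```python
from brainpy.base.naming import clear_name_cache
import brainpy.math as bm

a = bm.BrainPyObject(name='net')
# Re-using 'net' for a second, different object raises UniqueNameError,
# e.g. when re-running a script in the same interpreter session...
clear_name_cache(ignore_warn=True)
b = bm.BrainPyObject(name='net')  # ...unless the name cache is cleared first
```
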
diff --git a/brainpy/check.py b/brainpy/check.py
index eff430b5f..d69cac3d8 100644
--- a/brainpy/check.py
+++ b/brainpy/check.py
@@ -3,7 +3,7 @@
from functools import wraps, partial
from typing import Union, Sequence, Dict, Callable, Tuple, Type, Optional, Any
-from jax import numpy as jnp, tree_util as jtu
+from jax import numpy as jnp
import numpy as np
import numpy as onp
from jax.experimental.host_callback import id_tap
@@ -550,8 +550,10 @@ def is_all_vars(dyn_vars: Any, out_as: str = 'tuple'):
def is_all_objs(targets: Any, out_as: str = 'tuple'):
- from brainpy.base import BrainPyObject
- return is_elem_or_seq_or_dict(targets, BrainPyObject, out_as)
+ global bm
+ if bm is None:
+ from brainpy import math as bm
+ return is_elem_or_seq_or_dict(targets, bm.BrainPyObject, out_as)
def _err_jit_true_branch(err_fun, x):
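
The change to `is_all_objs` replaces the eager `brainpy.base` import with a lazy `brainpy.math` import, breaking the old circular dependency between `check` and the base module. Usage is unchanged; a brief sketch (the tuple-normalizing behavior mirrors `is_all_vars` above and is otherwise an assumption):

```python
import brainpy.math as bm
from brainpy import check

a = bm.BrainPyObject(name='obj_a')
b = bm.BrainPyObject(name='obj_b')

# Accepts a single object, a sequence, or a dict of BrainPyObject instances.
objs = check.is_all_objs([a, b], out_as='tuple')
assert isinstance(objs, tuple) and len(objs) == 2
```
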
diff --git a/brainpy/checkpoints/__init__.py b/brainpy/checkpoints/__init__.py
new file mode 100644
index 000000000..797dfe1a5
--- /dev/null
+++ b/brainpy/checkpoints/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .serialization import *
+from .io import *
+
diff --git a/brainpy/checkpoints/io.py b/brainpy/checkpoints/io.py
new file mode 100644
index 000000000..71d39c31c
--- /dev/null
+++ b/brainpy/checkpoints/io.py
@@ -0,0 +1,383 @@
+# -*- coding: utf-8 -*-
+
+from typing import Dict, Type, Union, Tuple, List
+import logging
+import pickle
+
+import numpy as np
+
+from brainpy import errors
+import brainpy.math as bm
+
+
+logger = logging.getLogger('brainpy.checkpoints.io')
+
+__all__ = [
+ 'SUPPORTED_FORMATS',
+ 'save_as_h5', 'load_by_h5',
+ 'save_as_npz', 'load_by_npz',
+ 'save_as_pkl', 'load_by_pkl',
+ 'save_as_mat', 'load_by_mat',
+]
+
+SUPPORTED_FORMATS = ['.h5', '.hdf5', '.npz', '.pkl', '.mat']
+
+
+def check_dict_data(
+ a_dict: Dict,
+ key_type: Union[Type, Tuple[Type, ...]] = None,
+ val_type: Union[Type, Tuple[Type, ...]] = None,
+ name: str = None
+):
+ """Check the dict data."""
+ name = '' if (name is None) else f'"{name}"'
+ if not isinstance(a_dict, dict):
+ raise ValueError(f'{name} must be a dict, while we got {type(a_dict)}')
+ if key_type is not None:
+ for key, value in a_dict.items():
+ if not isinstance(key, key_type):
+ raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
+ f'while we got ({type(key)}, {type(value)})')
+ if val_type is not None:
+ for key, value in a_dict.items():
+ if not isinstance(value, val_type):
+ raise ValueError(f'{name} must be a dict of ({key_type}, {val_type}), '
+ f'while we got ({type(key)}, {type(value)})')
+
+
+def _check_module(module, module_name, ext):
+ """Check whether the required module is installed."""
+ if module is None:
+ raise errors.PackageMissingError(
+ '"{package}" must be installed when you want to save/load data with {ext} '
+ 'format. \nPlease install {package} through "pip install {package}" or '
+ '"conda install {package}".'.format(package=module_name, ext=ext)
+ )
+
+
+def _check_missing(variables, filename):
+ if len(variables):
+ logger.warning(f'There are variable states missed in {filename}. '
+ f'The missed variables are: {list(variables.keys())}.')
+
+
+def _check_target(target):
+ if not isinstance(target, bm.BrainPyObject):
+ raise TypeError(f'"target" must be instance of "{ bm.BrainPyObject.__name__}", but we got {type(target)}')
+
+
+not_found_msg = ('"{key}" is stored in {filename}. But we does '
+ 'not find it is defined as variable in {target}.')
+id_mismatch_msg = ('{key1} and {key2} is the same data in {filename}. '
+ 'But we found they are different in {target}.')
+
+DUPLICATE_KEY = 'duplicate_keys'
+DUPLICATE_TARGET = 'duplicate_targets'
+
+
+def _load(
+ target,
+ verbose: bool,
+ filename: str,
+ load_vars: dict,
+ duplicates: Tuple[List[str], List[str]],
+ remove_first_axis: bool = False
+):
+
+ # get variables
+ _check_target(target)
+ variables = target.vars(method='absolute', level=-1)
+ var_names_in_obj = list(variables.keys())
+
+ # read data from file
+ for key in load_vars.keys():
+ if verbose:
+ print(f'Loading {key} ...')
+ if key not in variables:
+ raise KeyError(not_found_msg.format(key=key, target=target.name, filename=filename))
+ if remove_first_axis:
+ value = load_vars[key][0]
+ else:
+ value = load_vars[key]
+ variables[key].value = bm.asarray(value)
+ var_names_in_obj.remove(key)
+
+ # check duplicate names
+ duplicate_keys = duplicates[0]
+ duplicate_targets = duplicates[1]
+ for key1, key2 in zip(duplicate_keys, duplicate_targets):
+ if key1 not in var_names_in_obj:
+ raise KeyError(not_found_msg.format(key=key1, target=target.name, filename=filename))
+ if id(variables[key1]) != id(variables[key2]):
+ raise ValueError(id_mismatch_msg.format(key1=key1, key2=key2, filename=filename, target=target.name))
+ var_names_in_obj.remove(key1)
+
+ # check missing names
+ if len(var_names_in_obj):
+ logger.warning(f'There are variable states missed in {filename}. '
+ f'The missed variables are: {var_names_in_obj}.')
+
+
+def _unique_and_duplicate(collector: dict):
+ gather = bm.ArrayCollector()
+ id2name = dict()
+ duplicates = ([], [])
+ for k, v in collector.items():
+ id_ = id(v)
+ if id_ not in id2name:
+ gather[k] = v
+ id2name[id_] = k
+ else:
+ k2 = id2name[id_]
+ duplicates[0].append(k)
+ duplicates[1].append(k2)
+ duplicates = (duplicates[0], duplicates[1])
+ return gather, duplicates
+
+
+def save_as_h5(filename: str, variables: dict):
+ """Save variables into a HDF5 file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to save.
+ variables: dict
+ All variables to save.
+ """
+ if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
+ raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
+ f'postfix of ".hdf5" and ".h5". But we got {filename}')
+
+ import h5py
+
+ # check variables
+ check_dict_data(variables, name='variables')
+ variables, duplicates = _unique_and_duplicate(variables)
+
+ # save
+ f = h5py.File(filename, "w")
+ for key, data in variables.items():
+ f[key] = bm.as_numpy(data)
+ if len(duplicates[0]):
+ f.create_dataset(DUPLICATE_TARGET, data='+'.join(duplicates[1]))
+ f.create_dataset(DUPLICATE_KEY, data='+'.join(duplicates[0]))
+ f.close()
+
+
+def load_by_h5(filename: str, target, verbose: bool = False):
+ """Load variables in a HDF5 file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to load variables.
+ target: BrainPyObject
+ The instance of :py:class:`~.brainpy.BrainPyObject`.
+ verbose: bool
+ Whether to report the load progress.
+ """
+ if not (filename.endswith('.hdf5') or filename.endswith('.h5')):
+ raise ValueError(f'Cannot load variables from a HDF5 file. We only support file with '
+ f'postfix of ".hdf5" and ".h5". But we got {filename}')
+
+ # read data
+ import h5py
+ load_vars = dict()
+ with h5py.File(filename, "r") as f:
+ for key in f.keys():
+ if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
+ load_vars[key] = np.asarray(f[key])
+ if DUPLICATE_KEY in f:
+ duplicate_keys = np.asarray(f[DUPLICATE_KEY]).item().decode("utf-8").split('+')
+ duplicate_targets = np.asarray(f[DUPLICATE_TARGET]).item().decode("utf-8").split('+')
+ duplicates = (duplicate_keys, duplicate_targets)
+ else:
+ duplicates = ([], [])
+
+ # assign values
+ _load(target, verbose, filename, load_vars, duplicates)
+
+
+def save_as_npz(filename, variables, compressed=False):
+ """Save variables into a numpy file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to store.
+ variables: dict
+ Variables to save.
+ compressed: bool
+ Whether we use the compressed mode.
+ """
+ if not filename.endswith('.npz'):
+ raise ValueError(f'Cannot save variables as a .npz file. We only support file with '
+ f'postfix of ".npz". But we got {filename}')
+
+ check_dict_data(variables, name='variables')
+ variables, duplicates = _unique_and_duplicate(variables)
+
+ # save
+ variables = {k: bm.as_numpy(v) for k, v in variables.items()}
+ if len(duplicates[0]):
+ variables[DUPLICATE_KEY] = np.asarray(duplicates[0])
+ variables[DUPLICATE_TARGET] = np.asarray(duplicates[1])
+ if compressed:
+ np.savez_compressed(filename, **variables)
+ else:
+ np.savez(filename, **variables)
+
+
+def load_by_npz(filename, target, verbose=False):
+ """Load variables from a numpy file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to load variables.
+ target: BrainPyObject
+ The instance of :py:class:`~.brainpy.BrainPyObject`.
+ verbose: bool
+ Whether to report the load progress.
+ """
+ if not filename.endswith('.npz'):
+ raise ValueError(f'Cannot load variables from a .npz file. We only support file with '
+ f'postfix of ".npz". But we got {filename}')
+
+ # load data
+ load_vars = dict()
+ all_data = np.load(filename)
+ for key in all_data.files:
+ if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
+ load_vars[key] = all_data[key]
+ if DUPLICATE_KEY in all_data:
+ duplicate_keys = all_data[DUPLICATE_KEY].tolist()
+ duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
+ duplicates = (duplicate_keys, duplicate_targets)
+ else:
+ duplicates = ([], [])
+
+ # assign values
+ _load(target, verbose, filename, load_vars, duplicates)
+
+
+def save_as_pkl(filename, variables):
+ """Save variables into a pickle file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to save.
+ variables: dict
+ All variables to save.
+ """
+ if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
+ raise ValueError(f'Cannot save variables into a pickle file. We only support file with '
+ f'postfix of ".pkl" and ".pickle". But we got {filename}')
+
+ check_dict_data(variables, name='variables')
+ variables, duplicates = _unique_and_duplicate(variables)
+ targets = {k: bm.as_numpy(v) for k, v in variables.items()}
+ if len(duplicates[0]) > 0:
+ targets[DUPLICATE_KEY] = np.asarray(duplicates[0])
+ targets[DUPLICATE_TARGET] = np.asarray(duplicates[1])
+ with open(filename, 'wb') as f:
+ pickle.dump(targets, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+
+def load_by_pkl(filename, target, verbose=False):
+ """Load variables from a pickle file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to load variables.
+ target: BrainPyObject
+ The instance of :py:class:`~.brainpy.BrainPyObject`.
+ verbose: bool
+ Whether to report the load progress.
+ """
+ if not (filename.endswith('.pkl') or filename.endswith('.pickle')):
+ raise ValueError(f'Cannot load variables from a pickle file. We only support file with '
+ f'postfix of ".pkl" and ".pickle". But we got {filename}')
+
+ # load variables
+ load_vars = dict()
+ with open(filename, 'rb') as f:
+ all_data = pickle.load(f)
+ for key, data in all_data.items():
+ if key in [DUPLICATE_KEY, DUPLICATE_TARGET]: continue
+ load_vars[key] = data
+ if DUPLICATE_KEY in all_data:
+ duplicate_keys = all_data[DUPLICATE_KEY].tolist()
+ duplicate_targets = all_data[DUPLICATE_TARGET].tolist()
+ duplicates = (duplicate_keys, duplicate_targets)
+ else:
+ duplicates = ([], [])
+
+ # assign data
+ _load(target, verbose, filename, load_vars, duplicates)
+
+
+def save_as_mat(filename, variables):
+ """Save variables into a matlab file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to save.
+ variables: dict
+ All variables to save.
+ """
+ if not filename.endswith('.mat'):
+ raise ValueError(f'Cannot save variables into a .mat file. We only support file with '
+ f'postfix of ".mat". But we got {filename}')
+
+ import scipy.io as sio
+
+ check_dict_data(variables, name='variables')
+ variables, duplicates = _unique_and_duplicate(variables)
+ variables = {k: np.expand_dims(bm.as_numpy(v), axis=0) for k, v in variables.items()}
+ if len(duplicates[0]):
+ variables[DUPLICATE_KEY] = np.expand_dims(np.asarray(duplicates[0]), axis=0)
+ variables[DUPLICATE_TARGET] = np.expand_dims(np.asarray(duplicates[1]), axis=0)
+ sio.savemat(filename, variables)
+
+
+def load_by_mat(filename, target, verbose=False):
+ """Load variables from a numpy file.
+
+ Parameters
+ ----------
+ filename: str
+ The filename to load variables.
+ target: BrainPyObject
+ The instance of :py:class:`~.brainpy.BrainPyObject`.
+ verbose: bool
+ Whether to report the load progress.
+ """
+ if not filename.endswith('.mat'):
+ raise ValueError(f'Cannot load variables from a .mat file. We only support file with '
+ f'postfix of ".mat". But we got {filename}')
+
+ import scipy.io as sio
+
+ # load data
+ load_vars = dict()
+ all_data = sio.loadmat(filename)
+ for key, data in all_data.items():
+ if key.startswith('__'):
+ continue
+ if key in [DUPLICATE_KEY, DUPLICATE_TARGET]:
+ continue
+ load_vars[key] = data[0]
+ if DUPLICATE_KEY in all_data:
+ duplicate_keys = [a.strip() for a in all_data[DUPLICATE_KEY].tolist()[0]]
+ duplicate_targets = [a.strip() for a in all_data[DUPLICATE_TARGET].tolist()[0]]
+ duplicates = (duplicate_keys, duplicate_targets)
+ else:
+ duplicates = ([], [])
+
+ # assign values
+ _load(target, verbose, filename, load_vars, duplicates)
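
With the io helpers relocated under `brainpy.checkpoints`, a save/restore round trip works against any `bm.BrainPyObject` target; keys in the file are matched against the absolute variable paths returned by `target.vars()`. A minimal sketch (the `Holder` class is illustrative; restoring into the *same* instance sidesteps name-dependent key mismatches):

```python
import brainpy.math as bm
from brainpy import checkpoints

class Holder(bm.BrainPyObject):
    def __init__(self):
        super().__init__()
        self.w = bm.Variable(bm.ones(3))

src = Holder()
checkpoints.save_as_npz('state.npz', src.vars())  # keys like 'Holder0.w'
src.w.value = bm.zeros(3)                         # perturb the live state
checkpoints.load_by_npz('state.npz', src, verbose=True)  # restore from disk
```
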
diff --git a/brainpy/checkpoints.py b/brainpy/checkpoints/serialization.py
similarity index 99%
rename from brainpy/checkpoints.py
rename to brainpy/checkpoints/serialization.py
index 47a474566..22a777a86 100644
--- a/brainpy/checkpoints.py
+++ b/brainpy/checkpoints/serialization.py
@@ -36,7 +36,6 @@
get_tensorstore_spec = None
from brainpy import math as bm
-from brainpy.base import Collector
from brainpy.errors import (AlreadyExistsError,
MPACheckpointingRequiredError,
MPARestoreTargetRequiredError,
@@ -264,7 +263,7 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
register_serialization_state(bm.Array, _array_dict_state, _restore_array)
register_serialization_state(dict, _dict_state_dict, _restore_dict)
register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
-register_serialization_state(Collector, _dict_state_dict, _restore_dict)
+register_serialization_state(bm.Collector, _dict_state_dict, _restore_dict)
register_serialization_state(list, _list_state_dict, _restore_list)
register_serialization_state(tuple,
_list_state_dict,
diff --git a/brainpy/base/tests/test_io.py b/brainpy/checkpoints/tests/test_io.py
similarity index 100%
rename from brainpy/base/tests/test_io.py
rename to brainpy/checkpoints/tests/test_io.py
diff --git a/brainpy/connect/base.py b/brainpy/connect/base.py
index 3282ea813..2766d34c7 100644
--- a/brainpy/connect/base.py
+++ b/brainpy/connect/base.py
@@ -20,7 +20,7 @@
# the connection dtypes
'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE',
- # base class
+ # brainpy_object class
'Connector', 'TwoEndConnector', 'OneEndConnector',
# methods
@@ -133,11 +133,30 @@ def build_coo(self, ):
"""
- def __init__(self, ):
+ def __init__(
+ self,
+ pre: Union[int, Tuple[int, ...]] = None,
+ post: Union[int, Tuple[int, ...]] = None,
+ ):
self.pre_size = None
self.post_size = None
self.pre_num = None
self.post_num = None
+ if pre is not None:
+ if isinstance(pre, int):
+ pre = (pre,)
+ else:
+ pre = tuple(pre)
+ self.pre_size = pre
+ self.pre_num = tools.size2num(self.pre_size)
+ if post is not None:
+ if isinstance(post, int):
+ post = (post,)
+ else:
+ post = tuple(post)
+ self.post_size = post
+ self.post_num = tools.size2num(self.post_size)
+
def __repr__(self):
return self.__class__.__name__
@@ -303,13 +322,13 @@ def _make_returns(self, structures, conn_data):
if isinstance(conn_data, dict):
csr = conn_data.get('csr', None)
mat = conn_data.get('mat', None)
- coo = conn_data.get('coo', None)
+ coo = conn_data.get('coo', None) or conn_data.get('ij', None)
elif isinstance(conn_data, tuple):
if conn_data[0] == 'csr':
csr = conn_data[1]
elif conn_data[0] == 'mat':
mat = conn_data[1]
- elif conn_data[0] == 'coo':
+ elif conn_data[0] in ['coo', 'ij']:
coo = conn_data[1]
else:
raise ConnectorError(f'Must provide one of "csr", "mat" or "coo". Got "{conn_data[0]}" instead.')
@@ -541,8 +560,8 @@ def build_coo(self):
class OneEndConnector(TwoEndConnector):
"""Synaptic connector to build synapse connections within a population of neurons."""
- def __init__(self):
- super(OneEndConnector, self).__init__()
+ def __init__(self, *args, **kwargs):
+ super(OneEndConnector, self).__init__(*args, **kwargs)
def __call__(self, pre_size, post_size=None):
if post_size is None:
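With `pre`/`post` accepted at construction time, a connector can be fully specified without the separate `conn(pre_size, post_size)` call, which is what the experimental synapses below rely on. A sketch using the existing `require` API:

    import brainpy as bp

    # sizes are fixed up front, so pre_num/post_num are set immediately
    conn = bp.connect.FixedProb(prob=0.1, pre=100, post=100)
    pre_ids, post_ids = conn.require('pre_ids', 'post_ids')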
diff --git a/brainpy/connect/custom_conn.py b/brainpy/connect/custom_conn.py
index e23d330e5..54aa67b55 100644
--- a/brainpy/connect/custom_conn.py
+++ b/brainpy/connect/custom_conn.py
@@ -19,8 +19,8 @@
class MatConn(TwoEndConnector):
"""Connector built from the dense connection matrix."""
- def __init__(self, conn_mat):
- super(MatConn, self).__init__()
+ def __init__(self, conn_mat, **kwargs):
+ super(MatConn, self).__init__(**kwargs)
assert isinstance(conn_mat, (np.ndarray, bm.Array, jnp.ndarray)) and conn_mat.ndim == 2
self.pre_num, self.post_num = conn_mat.shape
@@ -42,8 +42,8 @@ def build_mat(self):
class IJConn(TwoEndConnector):
"""Connector built from the ``pre_ids`` and ``post_ids`` connections."""
- def __init__(self, i, j):
- super(IJConn, self).__init__()
+ def __init__(self, i, j, **kwargs):
+ super(IJConn, self).__init__(**kwargs)
assert isinstance(i, (np.ndarray, bm.Array, jnp.ndarray)) and i.ndim == 1
assert isinstance(j, (np.ndarray, bm.Array, jnp.ndarray)) and j.ndim == 1
@@ -78,8 +78,8 @@ def build_coo(self):
class CSRConn(TwoEndConnector):
"""Connector built from the CSR sparse connection matrix."""
- def __init__(self, indices, inptr):
- super(CSRConn, self).__init__()
+ def __init__(self, indices, inptr, **kwargs):
+ super(CSRConn, self).__init__(**kwargs)
self.indices = bm.asarray(indices, dtype=IDX_DTYPE)
self.inptr = bm.asarray(inptr, dtype=IDX_DTYPE)
@@ -99,7 +99,7 @@ def build_csr(self):
class SparseMatConn(CSRConn):
"""Connector built from the sparse connection matrix"""
- def __init__(self, csr_mat):
+ def __init__(self, csr_mat, **kwargs):
try:
from scipy.sparse import csr_matrix
except (ModuleNotFoundError, ImportError):
@@ -109,6 +109,7 @@ def __init__(self, csr_mat):
assert isinstance(csr_mat, csr_matrix)
self.csr_mat = csr_mat
super(SparseMatConn, self).__init__(indices=bm.asarray(self.csr_mat.indices, dtype=IDX_DTYPE),
- inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE))
+ inptr=bm.asarray(self.csr_mat.indptr, dtype=IDX_DTYPE),
+ **kwargs)
self.pre_num = csr_mat.shape[0]
self.post_num = csr_mat.shape[1]
diff --git a/brainpy/connect/random_conn.py b/brainpy/connect/random_conn.py
index 767964a83..09e4b3783 100644
--- a/brainpy/connect/random_conn.py
+++ b/brainpy/connect/random_conn.py
@@ -45,8 +45,14 @@ class FixedProb(TwoEndConnector):
Seed the random generator.
"""
- def __init__(self, prob, pre_ratio=1., include_self=True, allow_multi_conn=False, seed=None):
- super(FixedProb, self).__init__()
+ def __init__(self,
+ prob,
+ pre_ratio=1.,
+ include_self=True,
+ allow_multi_conn=False,
+ seed=None,
+ **kwargs):
+ super(FixedProb, self).__init__(**kwargs)
assert 0. <= prob <= 1.
assert 0. <= pre_ratio <= 1.
self.prob = prob
@@ -139,8 +145,8 @@ class FixedTotalNum(TwoEndConnector):
The random number seed.
"""
- def __init__(self, num, seed=None):
- super(FixedTotalNum, self).__init__()
+ def __init__(self, num, seed=None, **kwargs):
+ super(FixedTotalNum, self).__init__(**kwargs)
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
@@ -164,8 +170,13 @@ def __repr__(self):
class FixedNum(TwoEndConnector):
- def __init__(self, num, include_self=True, allow_multi_conn=False, seed=None):
- super(FixedNum, self).__init__()
+ def __init__(self,
+ num,
+ include_self=True,
+ allow_multi_conn=False,
+ seed=None,
+ **kwargs):
+ super(FixedNum, self).__init__(**kwargs)
if isinstance(num, int):
assert num >= 0, '"num" must be a non-negative integer.'
elif isinstance(num, float):
@@ -361,9 +372,10 @@ def __init__(
normalize: bool = True,
include_self: bool = True,
periodic_boundary: bool = False,
- seed: int = None
+ seed: int = None,
+ **kwargs
):
- super(GaussianProb, self).__init__()
+ super(GaussianProb, self).__init__(**kwargs)
self.sigma = sigma
self.encoding_values = encoding_values
self.normalize = normalize
@@ -478,9 +490,10 @@ def __init__(
prob,
directed=False,
include_self=False,
- seed=None
+ seed=None,
+ **kwargs
):
- super(SmallWorld, self).__init__()
+ super(SmallWorld, self).__init__(**kwargs)
self.prob = prob
self.directed = directed
self.num_neighbor = num_neighbor
@@ -610,8 +623,8 @@ class ScaleFreeBA(TwoEndConnector):
random networks", Science 286, pp 509-512, 1999.
"""
- def __init__(self, m, directed=False, seed=None):
- super(ScaleFreeBA, self).__init__()
+ def __init__(self, m, directed=False, seed=None, **kwargs):
+ super(ScaleFreeBA, self).__init__(**kwargs)
self.m = m
self.directed = directed
self.seed = format_seed(seed)
@@ -699,8 +712,8 @@ class ScaleFreeBADual(TwoEndConnector):
.. [1] N. Moshiri "The dual-Barabasi-Albert model", arXiv:1810.10538.
"""
- def __init__(self, m1, m2, p, directed=False, seed=None):
- super(ScaleFreeBADual, self).__init__()
+ def __init__(self, m1, m2, p, directed=False, seed=None, **kwargs):
+ super(ScaleFreeBADual, self).__init__(**kwargs)
self.m1 = m1
self.m2 = m2
self.p = p
@@ -812,8 +825,8 @@ class PowerLaw(TwoEndConnector):
Phys. Rev. E, 65, 026107, 2002.
"""
- def __init__(self, m, p, directed=False, seed=None):
- super(PowerLaw, self).__init__()
+ def __init__(self, m, p, directed=False, seed=None, **kwargs):
+ super(PowerLaw, self).__init__(**kwargs)
self.m = m
self.p = p
if self.p > 1 or self.p < 0:
@@ -910,8 +923,8 @@ class ProbDist(TwoEndConnector):
"""
- def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True):
- super(ProbDist, self).__init__()
+ def __init__(self, dist=1, prob=1., pre_ratio=1., seed=None, include_self=True, **kwargs):
+ super(ProbDist, self).__init__(**kwargs)
self.prob = prob
self.pre_ratio = pre_ratio
@@ -1013,7 +1026,7 @@ def _connect_4d(pre_pos, pre_size, post_size, n_dim):
self._connect_3d = numba_jit(_connect_3d)
self._connect_4d = numba_jit(_connect_4d)
- def build_conn(self):
+ def build_coo(self):
if len(self.pre_size) != len(self.post_size):
raise ValueError('The dimensions of shapes of two objects to establish connections should '
f'be the same. But we got dimension {len(self.pre_size)} != {len(self.post_size)}. '
@@ -1046,4 +1059,4 @@ def build_conn(self):
pres, posts = f(pre_pos, pre_size=pre_size, post_size=post_size, n_dim=n_dim)
connected_pres.extend(pres)
connected_posts.extend(posts)
- return 'ij', (np.asarray(connected_pres), np.asarray(connected_posts))
+ return np.asarray(connected_pres), np.asarray(connected_posts)
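Since `ProbDist` now implements `build_coo` and returns a bare `(pre_ids, post_ids)` pair, while `_make_returns` above treats the legacy `'ij'` tag as an alias of `'coo'`, existing call sites keep working. A sketch, assuming `'coo'` is a requestable structure:

    import brainpy as bp

    conn = bp.connect.ProbDist(dist=2, prob=0.5, include_self=False)
    conn(pre_size=(10, 10), post_size=(10, 10))
    pre_ids, post_ids = conn.require('coo')  # served by build_coo()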
diff --git a/brainpy/connect/regular_conn.py b/brainpy/connect/regular_conn.py
index 2a13dd7ae..f161cd377 100644
--- a/brainpy/connect/regular_conn.py
+++ b/brainpy/connect/regular_conn.py
@@ -22,8 +22,8 @@ class One2One(TwoEndConnector):
The two neuron groups should have the same size.
"""
- def __init__(self):
- super(One2One, self).__init__()
+ def __init__(self, *args, **kwargs):
+ super(One2One, self).__init__(*args, **kwargs)
def __call__(self, pre_size, post_size):
super(One2One, self).__call__(pre_size, post_size)
@@ -66,9 +66,9 @@ class All2All(TwoEndConnector):
will create (num_pre x num_post) synapses.
"""
- def __init__(self, include_self=True):
+ def __init__(self, *args, include_self=True, **kwargs):
self.include_self = include_self
- super(All2All, self).__init__()
+ super(All2All, self).__init__(*args, **kwargs)
def __repr__(self):
return f'{self.__class__.__name__}(include_self={self.include_self})'
@@ -100,8 +100,9 @@ def __init__(
strides,
include_self: bool = False,
periodic_boundary: bool = False,
+ **kwargs
):
- super(GridConn, self).__init__()
+ super(GridConn, self).__init__(**kwargs)
self.strides = strides
self.include_self = include_self
self.periodic_boundary = periodic_boundary
@@ -196,11 +197,13 @@ class GridFour(GridConn):
def __init__(
self,
include_self: bool = False,
- periodic_boundary: bool = False
+ periodic_boundary: bool = False,
+ **kwargs
):
super(GridFour, self).__init__(strides=np.asarray([-1, 0, 1]),
include_self=include_self,
- periodic_boundary=periodic_boundary)
+ periodic_boundary=periodic_boundary,
+ **kwargs)
self.include_self = include_self
self.periodic_boundary = periodic_boundary
@@ -244,11 +247,13 @@ def __init__(
self,
N: int = 1,
include_self: bool = False,
- periodic_boundary: bool = False
+ periodic_boundary: bool = False,
+ **kwargs
):
super(GridN, self).__init__(strides=np.arange(-N, N + 1, 1),
include_self=include_self,
- periodic_boundary=periodic_boundary)
+ periodic_boundary=periodic_boundary,
+ **kwargs)
self.N = N
def __repr__(self):
@@ -281,8 +286,11 @@ class GridEight(GridN):
.. versionadded:: 2.2.3.2
"""
- def __init__(self, include_self=False, periodic_boundary: bool = False):
- super(GridEight, self).__init__(N=1, include_self=include_self, periodic_boundary=periodic_boundary)
+ def __init__(self, include_self=False, periodic_boundary: bool = False, **kwargs):
+ super(GridEight, self).__init__(N=1,
+ include_self=include_self,
+ periodic_boundary=periodic_boundary,
+ **kwargs)
grid_eight = GridEight()
diff --git a/brainpy/dyn/_utils.py b/brainpy/dyn/_utils.py
index cd804c653..10144621c 100644
--- a/brainpy/dyn/_utils.py
+++ b/brainpy/dyn/_utils.py
@@ -3,16 +3,16 @@
from typing import Optional
import brainpy.math as bm
-from brainpy.base import BrainPyObject
-
__all__ = [
'get_output_var',
]
-def get_output_var(out_var: Optional[str],
- target: BrainPyObject) -> Optional[bm.Variable]:
+def get_output_var(
+ out_var: Optional[str],
+ target: bm.BrainPyObject
+) -> Optional[bm.Variable]:
if out_var is not None:
assert isinstance(out_var, str)
if not hasattr(target, out_var):
@@ -21,4 +21,3 @@ def get_output_var(out_var: Optional[str],
if not isinstance(out_var, bm.Variable):
-      raise ValueError(f'{target} does not has variable {out_var}')
+      raise ValueError(f'{target} does not have variable {out_var}')
return out_var
-
diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py
index 2c349facd..ffb6fd539 100644
--- a/brainpy/dyn/base.py
+++ b/brainpy/dyn/base.py
@@ -9,8 +9,6 @@
from brainpy import tools, math as bm
from brainpy.algorithms import OnlineAlgorithm, OfflineAlgorithm
-from brainpy.base.base import BrainPyObject
-from brainpy.base.collector import Collector
from brainpy.connect import TwoEndConnector, MatConn, IJConn, One2One, All2All
from brainpy.errors import NoImplementationError, UnsupportedError
from brainpy.initialize import Initializer, parameter, variable, Uniform, noise as init_noise
@@ -45,7 +43,7 @@
SLICE_VARS = 'slice_vars'
-class DynamicalSystem(BrainPyObject):
+class DynamicalSystem(bm.BrainPyObject):
"""Base Dynamical System class.
.. note::
@@ -97,7 +95,7 @@ def __init__(
super(DynamicalSystem, self).__init__(name=name)
# local delay variables
- self.local_delay_vars: Dict[str, bm.LengthDelay] = Collector()
+ self.local_delay_vars: Dict[str, bm.LengthDelay] = bm.Collector()
# fitting parameters
self.online_fit_by = None
@@ -119,9 +117,9 @@ def mode(self, value):
def __repr__(self):
return f'{self.__class__.__name__}(name={self.name}, mode={self.mode})'
- def __call__(self, *args, **kwargs):
+ def __call__(self, shared: Dict, *args, **kwargs):
"""The shortcut to call ``update`` methods."""
- return self.update(*args, **kwargs)
+ return self.update(shared, *args, **kwargs)
def register_delay(
self,
@@ -377,12 +375,14 @@ class FuncAsDynSys(DynamicalSystem):
The computation mode.
"""
- def __init__(self,
- f: Callable,
- child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
- dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
- name: str = None,
- mode: bm.Mode = None):
+ def __init__(
+ self,
+ f: Callable,
+ child_objs: Union[bm.BrainPyObject, Sequence[bm.BrainPyObject], Dict[str, bm.BrainPyObject]] = None,
+ dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
+ name: Optional[str] = None,
+ mode: Optional[bm.Mode] = None
+ ):
super().__init__(name=name, mode=mode)
self._f = f
@@ -437,8 +437,8 @@ def __init__(
parent = DynamicalSystem
parent_name = DynamicalSystem.__name__
else:
- parent = BrainPyObject
- parent_name = BrainPyObject.__name__
+ parent = bm.BrainPyObject
+ parent_name = bm.BrainPyObject.__name__
# add tuple-typed components
for module in dynamical_systems_as_tuple:
@@ -567,8 +567,8 @@ def __init__(
):
self._modules = tuple(modules_as_tuple) + tuple(modules_as_dict.values())
- seq_modules = [m for m in modules_as_tuple if isinstance(m, BrainPyObject)]
- dict_modules = {k: m for k, m in modules_as_dict.items() if isinstance(m, BrainPyObject)}
+ seq_modules = [m for m in modules_as_tuple if isinstance(m, bm.BrainPyObject)]
+ dict_modules = {k: m for k, m in modules_as_dict.items() if isinstance(m, bm.BrainPyObject)}
super().__init__(*seq_modules,
name=name,
@@ -839,9 +839,9 @@ def __init__(
# pre or post neuron group
# ------------------------
- if not isinstance(pre, NeuGroup):
+ if not isinstance(pre, (NeuGroup, DynamicalSystem)):
raise TypeError('"pre" must be an instance of NeuGroup.')
- if not isinstance(post, NeuGroup):
+ if not isinstance(post, (NeuGroup, DynamicalSystem)):
raise TypeError('"post" must be an instance of NeuGroup.')
self.pre = pre
self.post = post
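`__call__` now takes the shared-argument dictionary as an explicit first positional argument before forwarding to `update`. A minimal sketch of what a caller provides (the keys `t`, `dt`, `i` follow the `tdi` convention used elsewhere in this diff):

    import brainpy as bp
    import brainpy.math as bm

    model = bp.neurons.LIF(10)
    shared = {'t': 0., 'dt': bm.get_dt(), 'i': 0}
    model(shared, bm.ones(10) * 5.)  # same as model.update(shared, input)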
diff --git a/brainpy/dyn/channels/Ca.py b/brainpy/dyn/channels/Ca.py
index 8eb98b2e5..cf1bf1248 100644
--- a/brainpy/dyn/channels/Ca.py
+++ b/brainpy/dyn/channels/Ca.py
@@ -639,9 +639,9 @@ class ICaT_HM1992(ICa_p2q_ss):
T : float, ArrayType
The temperature.
T_base_p : float, ArrayType
    The base temperature factor of :math:`p` channel.
T_base_q : float, ArrayType
    The base temperature factor of :math:`q` channel.
g_max : float, ArrayType, Callable, Initializer
The maximum conductance.
V_sh : float, ArrayType, Callable, Initializer
@@ -736,9 +736,9 @@ class ICaT_HP1992(ICa_p2q_ss):
T : float, ArrayType
The temperature.
T_base_p : float, ArrayType
    The base temperature factor of :math:`p` channel.
T_base_q : float, ArrayType
    The base temperature factor of :math:`q` channel.
g_max : float, ArrayType, Callable, Initializer
The maximum conductance.
V_sh : float, ArrayType, Callable, Initializer
@@ -837,9 +837,9 @@ class ICaHT_HM1992(ICa_p2q_ss):
T : float, ArrayType
The temperature.
T_base_p : float, ArrayType
    The base temperature factor of :math:`p` channel.
T_base_q : float, ArrayType
    The base temperature factor of :math:`q` channel.
g_max : float, ArrayType, Initializer, Callable
The maximum conductance.
V_sh : float, ArrayType, Initializer, Callable
@@ -941,9 +941,9 @@ class ICaHT_Re1993(ICa_p2q_markov):
T : float, ArrayType
The temperature.
T_base_p : float, ArrayType
    The base temperature factor of :math:`p` channel.
T_base_q : float, ArrayType
    The base temperature factor of :math:`q` channel.
phi_p : optional, float, ArrayType, Callable, Initializer
The temperature factor for channel :math:`p`.
If `None`, :math:`\phi_p = \mathrm{T_base_p}^{\frac{T-23}{10}}`.
@@ -1029,9 +1029,9 @@ class ICaL_IS2008(ICa_p2q_ss):
T : float
The temperature.
T_base_p : float
    The base temperature factor of :math:`p` channel.
T_base_q : float
    The base temperature factor of :math:`q` channel.
g_max : float
The maximum conductance.
V_sh : float
diff --git a/brainpy/dyn/channels/K.py b/brainpy/dyn/channels/K.py
index 13e57ebb7..11915ae63 100644
--- a/brainpy/dyn/channels/K.py
+++ b/brainpy/dyn/channels/K.py
@@ -145,7 +145,7 @@ class IKDR_Ba2002(IK_p4_markov):
E : float, ArrayType, Initializer, Callable
The reversal potential (mV).
T_base : float, ArrayType
    The base of the temperature factor.
T : float, ArrayType, Initializer, Callable
The temperature (Celsius, :math:`^{\circ}C`).
V_sh : float, ArrayType, Initializer, Callable
diff --git a/brainpy/dyn/channels/__init__.py b/brainpy/dyn/channels/__init__.py
index 7b16d5c8a..326e68b12 100644
--- a/brainpy/dyn/channels/__init__.py
+++ b/brainpy/dyn/channels/__init__.py
@@ -1,5 +1,9 @@
# -*- coding: utf-8 -*-
+"""
+
+Access through ``brainpy.channels``.
+"""
from . import base, Ca, IH, K, Na, KCa, leaky
diff --git a/brainpy/dyn/channels/base.py b/brainpy/dyn/channels/base.py
index 3987cce1e..dc63e73fa 100644
--- a/brainpy/dyn/channels/base.py
+++ b/brainpy/dyn/channels/base.py
@@ -62,7 +62,7 @@ def __repr__(self):
class Calcium(Ion, Container):
  """The base calcium dynamics.
Parameters
----------
diff --git a/brainpy/dyn/layers/dropout.py b/brainpy/dyn/layers/dropout.py
index d2ac998c8..6cbbde346 100644
--- a/brainpy/dyn/layers/dropout.py
+++ b/brainpy/dyn/layers/dropout.py
@@ -44,7 +44,7 @@ def __init__(
):
super(Dropout, self).__init__(mode=mode, name=name)
self.prob = prob
- self.rng = bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed)
def update(self, sha, x):
if sha.get('fit', True):
diff --git a/brainpy/dyn/layers/reservoir.py b/brainpy/dyn/layers/reservoir.py
index 946bef31e..ab4cfbac7 100644
--- a/brainpy/dyn/layers/reservoir.py
+++ b/brainpy/dyn/layers/reservoir.py
@@ -120,7 +120,7 @@ def __init__(
self.activation = bm.activations.get(activation)
self.activation_type = activation_type
is_string(activation_type, 'activation_type', ['internal', 'external'])
- self.rng = bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed)
is_float(spectral_radius, 'spectral_radius', allow_none=True)
self.spectral_radius = spectral_radius
diff --git a/brainpy/dyn/layers/tests/test_conv.py b/brainpy/dyn/layers/tests/test_conv.py
index b6be8b710..e25c6fa2c 100644
--- a/brainpy/dyn/layers/tests/test_conv.py
+++ b/brainpy/dyn/layers/tests/test_conv.py
@@ -7,7 +7,7 @@
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
-
+bp.math.random.seed()
class TestConv(TestCase):
def test_Conv2D_img(self):
diff --git a/brainpy/dyn/layers/tests/test_pooling.py b/brainpy/dyn/layers/tests/test_pooling.py
index 04e6ce106..f6c6ccbaf 100644
--- a/brainpy/dyn/layers/tests/test_pooling.py
+++ b/brainpy/dyn/layers/tests/test_pooling.py
@@ -7,6 +7,7 @@
import brainpy as bp
import brainpy.math as bm
+bm.random.seed()
class TestPool(TestCase):
diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py
index 14b239140..85b7f6ef2 100644
--- a/brainpy/dyn/neurons/input_groups.py
+++ b/brainpy/dyn/neurons/input_groups.py
@@ -191,14 +191,14 @@ def __init__(
# variables
self.spike = variable_(lambda s: bm.zeros(s, dtype=bool), self.varshape, self.mode)
- self.rng = bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed)
def update(self, tdi, x=None):
shape = (self.spike.shape[:1] + self.varshape) if isinstance(self.mode, bm.BatchingMode) else self.varshape
self.spike.update(self.rng.random(shape) <= (self.freqs * tdi['dt'] / 1000.))
def reset(self, batch_size=None):
- self.rng.seed(self.seed)
+ self.rng.value = bm.random.get_rng(self.seed)
self.reset_state(batch_size)
def reset_state(self, batch_size=None):
diff --git a/brainpy/dyn/neurons/reduced_models.py b/brainpy/dyn/neurons/reduced_models.py
index 465fd910a..541090531 100644
--- a/brainpy/dyn/neurons/reduced_models.py
+++ b/brainpy/dyn/neurons/reduced_models.py
@@ -87,7 +87,7 @@ def __init__(
mode=mode,
keep_size=keep_size,
name=name)
- is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode), self.__class__)
+ is_subclass(self.mode, (bm.TrainingMode, bm.NonBatchingMode))
# parameters
self.V_rest = parameter(V_rest, self.varshape, allow_none=False)
diff --git a/brainpy/dyn/rates/populations.py b/brainpy/dyn/rates/populations.py
index b04ef65d6..e1dd57535 100644
--- a/brainpy/dyn/rates/populations.py
+++ b/brainpy/dyn/rates/populations.py
@@ -945,7 +945,7 @@ def __init__(
-    self.Ie = variable(bm.zeros, self.mode, self.varshape)  # Input of excitaory population
+    self.Ie = variable(bm.zeros, self.mode, self.varshape)  # Input of excitatory population
self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population
if bm.any(self.noise_e != 0) or bm.any(self.noise_i != 0):
- self.rng = bm.random.RandomState(self.seed)
+ self.rng = bm.random.get_rng(seed)
def reset(self, batch_size=None):
self.rng.seed(self.seed)
diff --git a/brainpy/dyn/runners.py b/brainpy/dyn/runners.py
index 04ca9e400..fe61d8937 100644
--- a/brainpy/dyn/runners.py
+++ b/brainpy/dyn/runners.py
@@ -16,7 +16,7 @@
from brainpy import math as bm, tools
from brainpy.check import is_float, serialize_kwargs
from brainpy.dyn.base import DynamicalSystem
-from brainpy.errors import RunningError
+from brainpy.errors import RunningError, NoLongerSupportError
from brainpy.running.runner import Runner
from brainpy.types import ArrayType, Output, Monitor
@@ -28,6 +28,10 @@
SUPPORTED_INPUT_TYPE = ['fix', 'iter', 'func']
+def _is_brainpy_array(x):
+ return isinstance(x, bm.Array)
+
+
def check_and_format_inputs(host, inputs):
"""Check inputs and get the formatted inputs for the given population.
@@ -289,10 +293,10 @@ class DSRunner(Runner):
numpy_mon_after_run : bool
When finishing the network running, transform the JAX arrays into numpy ndarray or not?
- time_major: bool
+ data_first_axis: str
Set the default data dimension arrangement.
- To indicate whether the first axis is the batch size (``time_major=False``) or the
- time length (``time_major=True``).
+ To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the
+ time length (``data_first_axis='T'``).
-      In order to be compatible with previous API, default is set to be ``False``.
+      If not provided, it is inferred from the target's mode: ``'B'`` for batching
+      mode and ``'T'`` otherwise.
.. versionadded:: 2.3.1
@@ -305,25 +309,25 @@ def __init__(
target: DynamicalSystem,
# inputs for target variables
- inputs: Sequence = (),
- time_major: bool = False,
+ inputs: Union[Sequence, Callable] = (),
# monitors
- monitors: Union[Sequence, Dict] = None,
+ monitors: Optional[Union[Sequence, Dict]] = None,
numpy_mon_after_run: bool = True,
# jit
jit: Union[bool, Dict[str, bool]] = True,
- dyn_vars: Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]] = None,
+ dyn_vars: Optional[Union[bm.Variable, Sequence[bm.Variable], Dict[str, bm.Variable]]] = None,
# extra info
- dt: float = None,
+ dt: Optional[float] = None,
t0: Union[float, int] = 0.,
progress_bar: bool = True,
+ data_first_axis: Optional[str] = None,
# deprecated
- fun_inputs: Callable = None,
- fun_monitors: Dict[str, Callable] = None,
+ fun_inputs: Optional[Callable] = None,
+ fun_monitors: Optional[Dict[str, Callable]] = None,
):
if not isinstance(target, DynamicalSystem):
raise RunningError(f'"target" must be an instance of {DynamicalSystem.__name__}, '
@@ -341,7 +345,10 @@ def __init__(
self._t0 = t0
self.i0 = bm.Variable(bm.asarray([1], dtype=bm.int_))
self.t0 = bm.Variable(bm.asarray([t0], dtype=bm.float_))
- self.time_major = time_major
+ if data_first_axis is None:
+ data_first_axis = 'B' if isinstance(self.target.mode, bm.BatchingMode) else 'T'
+ assert data_first_axis in ['B', 'T']
+ self.data_first_axis = data_first_axis
# parameters
dt = bm.get_dt() if dt is None else dt
@@ -369,12 +376,12 @@ def __repr__(self):
return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n'
f'{indent}jit={self.jit},\n'
f'{indent}dt={self.dt},\n'
- f'{indent}time_major={self.time_major})')
+ f'{indent}data_first_axis={self.data_first_axis})')
def reset_state(self):
"""Reset state of the ``DSRunner``."""
- self.i0[0] = 0
- self.t0[0] = self._t0
+ self.i0.value = bm.zeros_like(self.i0)
+ self.t0.value = bm.ones_like(self.t0) * self._t0
def predict(
self,
@@ -404,8 +411,8 @@ def predict(
- If the mode of ``target`` is instance of :py:class:`~.BatchingMode`,
``inputs`` must be a PyTree of data with two dimensions:
- ``(batch, time, ...)`` when ``time_major=False``,
- or ``(time, batch, ...)`` when ``time_major=True``.
+ ``(batch, time, ...)`` when ``data_first_axis='B'``,
+ or ``(time, batch, ...)`` when ``data_first_axis='T'``.
- If the mode of ``target`` is instance of :py:class:`~.NonBatchingMode`,
the ``inputs`` should be a PyTree of data with one dimension:
``(time, ...)``.
@@ -429,7 +436,7 @@ def predict(
"""
if inputs_are_batching is not None:
- raise ValueError(
+ raise NoLongerSupportError(
f'''
`inputs_are_batching` is no longer supported.
The target mode of {self.target.mode} has already indicated the input should be batching.
@@ -459,7 +466,7 @@ def predict(
shared['i'] += self.i0
shared['t'] += self.t0
- if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
+ if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
inputs,
is_leaf=lambda x: isinstance(x, bm.Array))
@@ -527,7 +534,7 @@ def _predict(
"""
_predict_func = self._get_f_predict(shared_args)
outs_and_mons = _predict_func(xs)
- if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
+ if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
outs_and_mons = tree_map(lambda x: bm.moveaxis(x, 0, 1),
outs_and_mons,
is_leaf=lambda x: isinstance(x, bm.Array))
@@ -569,8 +576,10 @@ def _get_input_batch_size(self, xs=None) -> Optional[int]:
return None
if isinstance(self.target.mode, bm.NonBatchingMode):
return None
- leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.Array))
- if self.time_major:
+ if isinstance(xs, (bm.Array, jax.Array, np.ndarray)):
+ return xs.shape[1] if self.data_first_axis == 'T' else xs.shape[0]
+ leaves, _ = tree_flatten(xs, is_leaf=_is_brainpy_array)
+ if self.data_first_axis == 'T':
num_batch_sizes = [x.shape[1] for x in leaves]
else:
num_batch_sizes = [x.shape[0] for x in leaves]
@@ -585,13 +594,13 @@ def _get_input_time_step(self, duration=None, xs=None) -> int:
return int(duration / self.dt)
if xs is not None:
if isinstance(xs, (bm.Array, jnp.ndarray)):
- return xs.shape[0] if self.time_major else xs.shape[1]
+ return xs.shape[0] if self.data_first_axis == 'T' else xs.shape[1]
else:
leaves, _ = tree_flatten(xs, is_leaf=lambda x: isinstance(x, bm.Array))
- if self.time_major:
- num_steps = [val.shape[0] for val in leaves]
+ if self.data_first_axis == 'T':
+ num_steps = [x.shape[0] for x in leaves]
else:
- num_steps = [val.shape[1] for val in leaves]
+ num_steps = [x.shape[1] for x in leaves]
if len(set(num_steps)) != 1:
raise ValueError(f'Number of time step is different across arrays in '
f'the provided "xs". We got {set(num_steps)}.')
@@ -612,6 +621,7 @@ def _step_func_predict(self, shared_args, t, i, x):
out = self.target(*args)
# monitor step
+ shared['t'] += self.dt
mon = self._step_func_monitor(shared)
# finally
@@ -625,8 +635,10 @@ def _get_f_predict(self, shared_args: Dict = None):
shared_kwargs_str = serialize_kwargs(shared_args)
if shared_kwargs_str not in self._f_predict_compiled:
- dyn_vars = self.vars().unique()
- dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+ dyn_vars = self.target.vars()
+ dyn_vars.update(self._dyn_vars)
+ dyn_vars.update(self.vars(level=0))
+ dyn_vars = dyn_vars.unique()
def run_func(all_inputs):
with jax.disable_jit(not self.jit['predict']):
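In short, `time_major=True/False` becomes `data_first_axis='T'/'B'`, and when left unset it is inferred from the target's mode (`'B'` for batching mode, `'T'` otherwise). A usage sketch; `bm.batching_mode` here is only to put the group in batched mode:

    import brainpy as bp
    import brainpy.math as bm

    model = bp.neurons.LIF(3, mode=bm.batching_mode)
    runner = bp.dyn.DSRunner(model, data_first_axis='T', progress_bar=False)
    xs = bm.random.rand(100, 16, 3)  # (time, batch, num_neurons)
    out = runner.predict(xs)         # no axis flip needed for time-major data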
diff --git a/brainpy/dyn/synapses/abstract_models.py b/brainpy/dyn/synapses/abstract_models.py
index dd804633c..6337d1dcb 100644
--- a/brainpy/dyn/synapses/abstract_models.py
+++ b/brainpy/dyn/synapses/abstract_models.py
@@ -280,7 +280,7 @@ def __init__(
pre: NeuGroup,
post: NeuGroup,
conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]],
- output: SynOut = CUBA(),
+ output: Optional[SynOut] = CUBA(),
stp: Optional[SynSTP] = None,
comp_method: str = 'sparse',
g_max: Union[float, ArrayType, Initializer, Callable] = 1.,
@@ -973,7 +973,7 @@ def __init__(
self.freq = freq
self.weight = weight
self.seed = seed
- self.rng = bm.random.RandomState(self.seed)
+ self.rng = bm.random.get_rng(seed)
def update(self, tdi):
p = self.freq * tdi.dt / 1e3
diff --git a/brainpy/dyn/transform.py b/brainpy/dyn/transform.py
index 875baf41c..6becf3aa7 100644
--- a/brainpy/dyn/transform.py
+++ b/brainpy/dyn/transform.py
@@ -6,7 +6,6 @@
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
from brainpy import tools, math as bm
-from brainpy.base.base import BrainPyObject
from brainpy.check import is_float
from brainpy.types import PyTree
from .base import DynamicalSystem, Sequential
@@ -17,7 +16,7 @@
]
-class DynSysToBPObj(BrainPyObject):
+class DynSysToBPObj(bm.BrainPyObject):
"""Transform a :py:class:`DynamicalSystem` to a :py:class:`BrainPyObject`.
Parameters
@@ -72,13 +71,13 @@ class LoopOverTime(DynSysToBPObj):
>>> over_time.reset_state(n_batch)
(30, 128, 2)
>>>
- >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in), time_major=True)
+ >>> hist_l3 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T')
>>> print(hist_l3.shape)
>>>
>>> # monitor the "l1" layer state
>>> over_time = bp.LoopOverTime(model, out_vars=model['l1'].state)
>>> over_time.reset_state(n_batch)
- >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in), time_major=True)
+ >>> hist_l3, hist_l1 = over_time(bm.random.rand(n_time, n_batch, n_in), data_first_axis='T')
>>> print(hist_l3.shape)
(30, 128, 2)
>>> print(hist_l1.shape)
@@ -149,7 +148,7 @@ def __call__(
t0: float = 0.,
dt: Optional[float] = None,
shared_arg: Optional[Dict] = None,
- time_major: bool = True
+ data_first_axis: str = 'T'
):
"""Forward propagation along the time or inputs.
@@ -165,7 +164,7 @@ def __call__(
shared_arg: dict
The shared arguments across the nodes.
For instance, `shared_arg={'fit': False}` for the prediction phase.
- time_major: bool
+ data_first_axis: str
-      Denote whether the input data is time major.
-      If so, we treat the data as `(time, batch, ...)` when the `target` is in Batching mode.
-      Default is True.
+      The axis order of the input data: ``'T'`` treats it as `(time, batch, ...)`,
+      ``'B'`` as `(batch, time, ...)`, when the `target` is in batching mode.
+      Default is ``'T'``.
@@ -175,6 +174,8 @@ def __call__(
out: PyTree
The accumulated outputs over time.
"""
+ assert data_first_axis in ['B', 'T']
+
is_float(t0, 't0')
is_float(dt, 'dt', allow_none=True)
dt = bm.get_dt() if dt is None else dt
@@ -195,11 +196,11 @@ def __call__(
else:
inp_err_msg = ('\n'
'Input should be a Array PyTree with the shape '
- 'of (B, T, ...) or (T, B, ...) with `time_major=True`, '
+                     'of (B, T, ...) with `data_first_axis="B"` or (T, B, ...) with `data_first_axis="T"`, '
'where B the batch size and T the time length.')
xs, tree = tree_flatten(duration_or_xs, lambda a: isinstance(a, bm.Array))
if isinstance(self.target.mode, bm.BatchingMode):
- b_idx, t_idx = (1, 0) if time_major else (0, 1)
+ b_idx, t_idx = (1, 0) if data_first_axis == 'T' else (0, 1)
try:
batch = tuple(set([x.shape[b_idx] for x in xs]))
@@ -225,10 +226,10 @@ def __call__(
if self.no_state:
xs = [jnp.reshape(x, (length[0] * batch[0],) + x.shape[2:]) for x in xs]
else:
- if not time_major:
+ if data_first_axis == 'B':
xs = [jnp.moveaxis(x, 0, 1) for x in xs]
xs = tree_unflatten(tree, xs)
- origin_shape = (length[0], batch[0]) if time_major else (batch[0], length[0])
+ origin_shape = (length[0], batch[0]) if data_first_axis == 'T' else (batch[0], length[0])
else:
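The same rename applies to `LoopOverTime`: batch-first inputs are declared with `data_first_axis='B'`. Continuing the docstring example above (a sketch; `over_time`, `n_batch`, `n_time`, `n_in` are as defined there):

    hist = over_time(bm.random.rand(n_batch, n_time, n_in), data_first_axis='B')
    # inputs are moved to time-major form internally before the scan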
diff --git a/brainpy/encoding/base.py b/brainpy/encoding/base.py
index 84e19c584..04c32f1b0 100644
--- a/brainpy/encoding/base.py
+++ b/brainpy/encoding/base.py
@@ -1,14 +1,13 @@
# -*- coding: utf-8 -*-
-
-from brainpy.base.base import BrainPyObject
+import brainpy.math as bm
__all__ = [
'Encoder'
]
-class Encoder(BrainPyObject):
+class Encoder(bm.BrainPyObject):
"""Base class for encoding rate values as spike trains."""
def __call__(self, *args, **kwargs):
raise NotImplementedError
diff --git a/brainpy/encoding/stateless_encoding.py b/brainpy/encoding/stateless_encoding.py
index b7b31a4aa..161675605 100644
--- a/brainpy/encoding/stateless_encoding.py
+++ b/brainpy/encoding/stateless_encoding.py
@@ -41,7 +41,7 @@ def __init__(self,
check.is_float(max_val, 'max_val')
self.min_val = min_val
self.max_val = max_val
- self.rng = bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed)
def __call__(self, x: ArrayType, num_step: int = None):
"""
diff --git a/brainpy/experimental/__init__.py b/brainpy/experimental/__init__.py
new file mode 100644
index 000000000..efaba3f24
--- /dev/null
+++ b/brainpy/experimental/__init__.py
@@ -0,0 +1,3 @@
+# -*- coding: utf-8 -*-
+
+from .synapses import *
diff --git a/brainpy/experimental/synapses.py b/brainpy/experimental/synapses.py
new file mode 100644
index 000000000..2a68eb96a
--- /dev/null
+++ b/brainpy/experimental/synapses.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+
+from typing import Union, Optional
+
+import brainpylib as bl
+import jax
+
+from brainpy import (math as bm,
+ initialize as init,
+ connect)
+from brainpy.dyn.base import DynamicalSystem, SynSTP
+from brainpy.integrators.ode import odeint
+from brainpy.types import Initializer, ArrayType
+
+__all__ = [
+ 'Exponential',
+]
+
+
+class Exponential(DynamicalSystem):
+ def __init__(
+ self,
+ conn: connect.TwoEndConnector,
+ stp: Optional[SynSTP] = None,
+ g_max: Union[float, Initializer] = 1.,
+ g_initializer: Union[float, Initializer] = init.ZeroInit(),
+ tau: Union[float, ArrayType] = 8.0,
+ method: str = 'exp_auto',
+ mode: Optional[bm.Mode] = None,
+ name: Optional[str] = None,
+ ):
+ super(Exponential, self).__init__(name=name, mode=mode)
+
+ # component
+ self.conn = conn
+ self.stp = stp
+ self.g_initializer = g_initializer
+ assert self.conn.pre_num is not None
+ assert self.conn.post_num is not None
+
+ # parameters
+ self.tau = tau
+ if bm.size(self.tau) != 1:
+ raise ValueError(f'"tau" must be a scalar or a tensor with size of 1. But we got {self.tau}')
+
+ # connections and weights
+ if isinstance(self.conn, connect.One2One):
+ self.g_max = init.parameter(g_max, (self.conn.pre_num,), allow_none=False)
+
+ elif isinstance(self.conn, connect.All2All):
+ self.g_max = init.parameter(g_max, (self.conn.pre_num, self.conn.post_num), allow_none=False)
+
+ else:
+ self.conn_mask = self.conn.require('pre2post')
+ self.g_max = init.parameter(g_max, self.conn_mask[0].shape, allow_none=False)
+
+ # variables
+ self.g = init.variable_(g_initializer, self.conn.post_num, self.mode)
+
+ # function
+ self.integral = odeint(lambda g, t: -g / self.tau, method=method)
+
+ def reset_state(self, batch_size=None):
+ self.g.value = init.variable_(bm.zeros, self.conn.post_num, batch_size)
+ if self.stp is not None:
+ self.stp.reset_state(batch_size)
+
+ def _syn2post_with_one2one(self, syn_value, syn_weight):
+ return syn_value * syn_weight
+
+ def _syn2post_with_all2all(self, syn_value, syn_weight):
+ if bm.ndim(syn_weight) == 0:
+ if isinstance(self.mode, bm.BatchingMode):
+ assert syn_value.ndim == 2
+ post_vs = bm.sum(syn_value, keepdims=True, axis=1)
+ else:
+ post_vs = bm.sum(syn_value)
+ if not self.conn.include_self:
+ post_vs = post_vs - syn_value
+ post_vs = syn_weight * post_vs
+ else:
+ assert syn_weight.ndim == 2
+      if isinstance(self.mode, bm.BatchingMode):
+        assert syn_value.ndim == 2
+      # batched (2D) and non-batched (1D) values both reduce to a matmul
+      post_vs = syn_value @ syn_weight
+ return post_vs
+
+ def update(self, tdi, spike):
+ t, dt = tdi['t'], tdi.get('dt', bm.dt)
+
+ # update sub-components
+ if self.stp is not None:
+ self.stp.update(tdi, spike)
+
+ # post values
+ if isinstance(self.conn, connect.All2All):
+ syn_value = bm.asarray(spike, dtype=bm.float_)
+ if self.stp is not None:
+ syn_value = self.stp(syn_value)
+ post_vs = self._syn2post_with_all2all(syn_value, self.g_max)
+ elif isinstance(self.conn, connect.One2One):
+ syn_value = bm.asarray(spike, dtype=bm.float_)
+ if self.stp is not None:
+ syn_value = self.stp(syn_value)
+ post_vs = self._syn2post_with_one2one(syn_value, self.g_max)
+ else:
+ if isinstance(self.mode, bm.BatchingMode):
+ f = jax.vmap(bl.event_ops.event_csr_matvec, in_axes=(None, None, None, 0))
+ post_vs = f(self.g_max, self.conn_mask[0], self.conn_mask[1], spike,
+ shape=(self.conn.pre_num, self.conn.post_num), transpose=True)
+ else:
+ post_vs = bl.event_ops.event_csr_matvec(
+ self.g_max, self.conn_mask[0], self.conn_mask[1], spike,
+ shape=(self.conn.pre_num, self.conn.post_num), transpose=True
+ )
+ # updates
+ self.g.value = self.integral(self.g.value, t, dt) + post_vs
+
+ # output
+ return self.g.value
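A usage sketch for the experimental synapse, assuming `brainpylib` is installed and relying on the constructor-time `pre`/`post` sizes added to the connectors above:

    import brainpy as bp
    import brainpy.math as bm

    conn = bp.connect.FixedProb(0.1, pre=100, post=100)
    syn = bp.experimental.Exponential(conn, tau=5.)
    spike = bm.random.rand(100) < 0.05
    g = syn.update({'t': 0., 'dt': bm.get_dt()}, spike)  # post-synaptic conductance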
diff --git a/brainpy/initialize/generic.py b/brainpy/initialize/generic.py
index 79cc40e12..1b9e6ec92 100644
--- a/brainpy/initialize/generic.py
+++ b/brainpy/initialize/generic.py
@@ -94,16 +94,16 @@ def init_param(
def variable_(
- data: Union[Callable, ArrayType],
+ init: Union[Callable, ArrayType],
size: Shape = None,
batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None,
batch_axis: int = 0,
):
- """Initialize variables. Same as `variable()`.
+  """Initialize a :py:class:`~.Variable` from a callable function or concrete data.
Parameters
----------
- data: callable, function, ArrayType
+ init: callable, function, ArrayType
The data to be initialized as a ``Variable``.
batch_size_or_mode: int, bool, Mode, optional
The batch size, model ``Mode``, boolean state.
@@ -125,11 +125,11 @@ def variable_(
variable, parameter, noise, delay
"""
- return variable(data, batch_size_or_mode, size, batch_axis)
+ return variable(init, batch_size_or_mode, size, batch_axis)
def variable(
- data: Union[Callable, ArrayType],
+ init: Union[Callable, ArrayType],
batch_size_or_mode: Optional[Union[int, bool, bm.Mode]] = None,
size: Shape = None,
batch_axis: int = 0,
@@ -138,7 +138,7 @@ def variable(
Parameters
----------
- data: callable, function, ArrayType
+ init: callable, function, ArrayType
The data to be initialized as a ``Variable``.
batch_size_or_mode: int, bool, Mode, optional
The batch size, model ``Mode``, boolean state.
@@ -161,34 +161,34 @@ def variable(
"""
size = to_size(size)
- if callable(data):
+ if callable(init):
if size is None:
raise ValueError('"varshape" cannot be None when data is a callable function.')
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
- return bm.Variable(data(size))
+ return bm.Variable(init(size))
elif isinstance(batch_size_or_mode, bm.BatchingMode):
new_shape = size[:batch_axis] + (1,) + size[batch_axis:]
- return bm.Variable(data(new_shape), batch_axis=batch_axis)
+ return bm.Variable(init(new_shape), batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
- return bm.Variable(data(size))
+ return bm.Variable(init(size))
elif isinstance(batch_size_or_mode, int):
new_shape = size[:batch_axis] + (int(batch_size_or_mode),) + size[batch_axis:]
- return bm.Variable(data(new_shape), batch_axis=batch_axis)
+ return bm.Variable(init(new_shape), batch_axis=batch_axis)
else:
raise ValueError('Unknown batch_size_or_mode.')
else:
if size is not None:
- if bm.shape(data) != size:
- raise ValueError(f'The shape of "data" {bm.shape(data)} does not match with "var_shape" {size}')
+ if bm.shape(init) != size:
+        raise ValueError(f'The shape of "init" {bm.shape(init)} does not match with "size" {size}.')
if isinstance(batch_size_or_mode, bm.NonBatchingMode):
- return bm.Variable(data)
+ return bm.Variable(init)
elif isinstance(batch_size_or_mode, bm.BatchingMode):
- return bm.Variable(bm.expand_dims(data, axis=batch_axis), batch_axis=batch_axis)
+ return bm.Variable(bm.expand_dims(init, axis=batch_axis), batch_axis=batch_axis)
elif batch_size_or_mode in (None, False):
- return bm.Variable(data)
+ return bm.Variable(init)
elif isinstance(batch_size_or_mode, int):
- return bm.Variable(bm.repeat(bm.expand_dims(data, axis=batch_axis),
+ return bm.Variable(bm.repeat(bm.expand_dims(init, axis=batch_axis),
int(batch_size_or_mode),
axis=batch_axis),
batch_axis=batch_axis)
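The `data` -> `init` rename makes the dual behavior explicit: a callable is invoked with the (possibly batch-augmented) shape, while concrete data is wrapped and expanded along the batch axis when needed. A quick sketch:

    import brainpy.math as bm
    from brainpy.initialize import variable_

    v1 = variable_(bm.zeros, size=(4, 3))                         # shape (4, 3)
    v2 = variable_(bm.zeros, size=(4, 3), batch_size_or_mode=16)  # shape (16, 4, 3)
    v3 = variable_(bm.ones((4, 3)), size=(4, 3),
                   batch_size_or_mode=bm.BatchingMode())          # shape (1, 4, 3)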
diff --git a/brainpy/initialize/random_inits.py b/brainpy/initialize/random_inits.py
index 39ed3b2a4..e80c03cf9 100644
--- a/brainpy/initialize/random_inits.py
+++ b/brainpy/initialize/random_inits.py
@@ -54,7 +54,7 @@ def __init__(self, mean=0., scale=1., seed=None):
super(Normal, self).__init__()
self.scale = scale
self.mean = mean
- self.rng = bm.random.DEFAULT if seed is None else bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed, clone=False)
def __call__(self, *shape, dtype=None):
shape = _format_shape(shape)
@@ -80,7 +80,7 @@ def __init__(self, min_val: float = 0., max_val: float = 1., seed=None):
super(Uniform, self).__init__()
self.min_val = min_val
self.max_val = max_val
- self.rng = bm.random.DEFAULT if seed is None else bm.random.RandomState(seed=seed)
+ self.rng = bm.random.get_rng(seed, clone=False)
def __call__(self, shape, dtype=None):
shape = _format_shape(shape)
@@ -107,7 +107,7 @@ def __init__(
self.in_axis = in_axis
self.out_axis = out_axis
self.distribution = distribution
- self.rng = bm.random.DEFAULT if seed is None else bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed, clone=False)
def __call__(self, shape, dtype=None):
shape = _format_shape(shape)
@@ -264,7 +264,7 @@ def __init__(
super(Orthogonal, self).__init__()
self.scale = scale
self.axis = axis
- self.rng = bm.random.DEFAULT if seed is None else bm.random.RandomState(seed)
+ self.rng = bm.random.get_rng(seed, clone=False)
def __call__(self, shape, dtype=None):
shape = _format_shape(shape)
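`get_rng(seed, clone=False)` replaces the `DEFAULT`-or-`RandomState` branching in each initializer; with `clone=False`, an unseeded initializer shares the global generator instead of cloning it (the `clone` semantics are inferred from this diff). Call sites are unchanged:

    import brainpy as bp

    init = bp.init.Normal(mean=0., scale=0.1, seed=42)  # reproducible draws
    w = init(128, 128)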
diff --git a/brainpy/inputs/currents.py b/brainpy/inputs/currents.py
index 5fa77bbf6..418908583 100644
--- a/brainpy/inputs/currents.py
+++ b/brainpy/inputs/currents.py
@@ -260,7 +260,7 @@ def wiener_process(duration, dt=None, n=1, t_start=0., t_end=None, seed=None):
dt = bm.get_dt() if dt is None else dt
is_float(dt, 'dt', allow_none=False, min_bound=0.)
is_integer(n, 'n', allow_none=False, min_bound=0)
- rng = bm.random.RandomState(seed)
+ rng = bm.random.get_rng(seed)
t_end = duration if t_end is None else t_end
i_start = int(t_start / dt)
i_end = int(t_end / dt)
@@ -302,7 +302,7 @@ def ou_process(mean, sigma, tau, duration, dt=None, n=1, t_start=0., t_end=None,
dt_sqrt = jnp.sqrt(dt)
is_float(dt, 'dt', allow_none=False, min_bound=0.)
is_integer(n, 'n', allow_none=False, min_bound=0)
- rng = bm.random.RandomState(seed)
+ rng = bm.random.get_rng(seed)
x = bm.Variable(jnp.ones(n) * mean)
def _f(t):
diff --git a/brainpy/integrators/base.py b/brainpy/integrators/base.py
index 6749a6d1e..056d5bd4f 100644
--- a/brainpy/integrators/base.py
+++ b/brainpy/integrators/base.py
@@ -4,7 +4,6 @@
from typing import Dict, Sequence, Union
import brainpy.math as bm
-from brainpy.base.base import BrainPyObject
from brainpy.errors import DiffEqError
from brainpy.integrators.constants import DT
from brainpy.check import is_float, is_dict_data
@@ -14,7 +13,7 @@
]
-class AbstractIntegrator(BrainPyObject):
+class AbstractIntegrator(bm.BrainPyObject):
"""Basic Integrator Class."""
# func_name
diff --git a/brainpy/integrators/constants.py b/brainpy/integrators/constants.py
index 3cc757e5e..b1d5280fa 100644
--- a/brainpy/integrators/constants.py
+++ b/brainpy/integrators/constants.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from brainpy.base import naming
+import brainpy.math as bm
__all__ = [
'DT',
@@ -131,15 +131,15 @@
def unique_name(type):
if type == 'ode':
- return naming.get_unique_name(ODE_INT)
+ return bm.get_unique_name(ODE_INT)
elif type == 'sde':
- return naming.get_unique_name(SDE_INT)
+ return bm.get_unique_name(SDE_INT)
elif type == 'dde':
- return naming.get_unique_name(DDE_INT)
+ return bm.get_unique_name(DDE_INT)
elif type == 'fde':
- return naming.get_unique_name(FDE_INT)
+ return bm.get_unique_name(FDE_INT)
elif type == 'pde':
- return naming.get_unique_name(PDE_INT)
+ return bm.get_unique_name(PDE_INT)
else:
raise ValueError(f'Unknown differential equation type: {type}')
diff --git a/brainpy/integrators/joint_eq.py b/brainpy/integrators/joint_eq.py
index 3f0130fe3..d63a7d326 100644
--- a/brainpy/integrators/joint_eq.py
+++ b/brainpy/integrators/joint_eq.py
@@ -2,8 +2,7 @@
import inspect
-from brainpy import errors
-from brainpy.base import Collector
+from brainpy import errors, math as bm
__all__ = [
'JointEq',
@@ -199,7 +198,7 @@ def __init__(self, *eqs):
def __call__(self, *args, **kwargs):
# format arguments
- params_in = Collector()
+ params_in = bm.Collector()
for i, arg in enumerate(args):
if i < len(self.arg_keys):
params_in[self.arg_keys[i]] = arg
diff --git a/brainpy/integrators/ode/exponential.py b/brainpy/integrators/ode/exponential.py
index 991c09fbe..f74871b04 100644
--- a/brainpy/integrators/ode/exponential.py
+++ b/brainpy/integrators/ode/exponential.py
@@ -109,7 +109,6 @@
from functools import wraps
from brainpy import math as bm, errors
-from brainpy.base.collector import Collector
from brainpy.integrators import constants as C, utils, joint_eq
from brainpy.integrators.ode.base import ODEIntegrator
from .generic import register_ode_integrator
@@ -323,7 +322,7 @@ def build(self):
@wraps(self.f)
def integral_func(*args, **kwargs):
# format arguments
- params_in = Collector()
+ params_in = bm.Collector()
for i, arg in enumerate(args):
params_in[all_vps[i]] = arg
params_in.update(kwargs)
diff --git a/brainpy/integrators/runner.py b/brainpy/integrators/runner.py
index d09db635e..b7640a0cc 100644
--- a/brainpy/integrators/runner.py
+++ b/brainpy/integrators/runner.py
@@ -12,7 +12,6 @@
import tqdm.auto
from jax.experimental.host_callback import id_tap
-from brainpy.base import Collector
from brainpy import math as bm
from brainpy.errors import RunningError, MonitorError
from brainpy.integrators.base import Integrator
@@ -169,8 +168,14 @@ def __init__(
self.variables[k][:] = inits[k]
# format string monitors
- monitors = self._format_seq_monitors(monitors)
- monitors = {k: (self.variables[k], i) for k, i in monitors}
+ if isinstance(monitors, (tuple, list)):
+ monitors = self._format_seq_monitors(monitors)
+ monitors = {k: (self.variables[k], i) for k, i in monitors}
+ elif isinstance(monitors, dict):
+ monitors = self._format_dict_monitors(monitors)
+      monitors = {k: ((self.variables[i], None) if isinstance(i, str) else i) for k, i in monitors.items()}
+ else:
+      raise ValueError(f'Unknown type of "monitors": {type(monitors)}')
# initialize super class
super(IntegratorRunner, self).__init__(target=target,
@@ -219,12 +224,6 @@ def __init__(
else:
self._dyn_args = dict()
- # monitors
- for k in self.mon.var_names:
- if k not in self.target.variables:
- raise MonitorError(f'Variable "{k}" to monitor is not defined '
- f'in the integrator {self.target}.')
-
# start simulation time and index
self.start_t = bm.Variable(bm.zeros(1))
self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_))
@@ -239,7 +238,7 @@ def _run_fun_integration(self, static_args, dyn_args, times, indices):
def _step_fun_integrator(self, static_args, dyn_args, t, i):
# arguments
- kwargs = Collector(dt=self.dt, t=t)
+ kwargs = bm.Collector(dt=self.dt, t=t)
kwargs.update(static_args)
kwargs.update(dyn_args)
kwargs.update({k: v.value for k, v in self.variables.items()})
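Monitors may now be given as a dict as well as a sequence; string values are resolved against the integrator's variables, so a recorded trace can carry a custom name. A sketch following the `IntegratorRunner` API (assuming the dict form resolves names the same way the sequence form does):

    import brainpy as bp

    integral = bp.odeint(lambda V, t: -V + 1.)
    runner = bp.integrators.IntegratorRunner(integral,
                                             monitors={'voltage': 'V'},
                                             inits={'V': 0.},
                                             dt=0.1)
    runner.run(100.)
    trace = runner.mon['voltage']  # the 'V' trajectory under a custom key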
diff --git a/brainpy/integrators/sde/base.py b/brainpy/integrators/sde/base.py
index 6d4f0c912..c5f86e4a7 100644
--- a/brainpy/integrators/sde/base.py
+++ b/brainpy/integrators/sde/base.py
@@ -73,7 +73,7 @@ def __init__(
self.wiener_type = wiener_type # wiener process type
# random seed
- self.rng = bm.random.RandomState()
+ self.rng = bm.random.get_rng()
# code scope
self.code_scope = {constants.F: f, constants.G: g, 'math': bm, 'random': self.rng}
diff --git a/brainpy/integrators/sde/normal.py b/brainpy/integrators/sde/normal.py
index 05c3b2838..0bd907781 100644
--- a/brainpy/integrators/sde/normal.py
+++ b/brainpy/integrators/sde/normal.py
@@ -5,7 +5,6 @@
import jax.numpy as jnp
from brainpy import errors, math as bm
-from brainpy.base import Collector
from brainpy.integrators import constants, utils, joint_eq
from brainpy.integrators.sde.base import SDEIntegrator
from .generic import register_sde_integrator
@@ -574,7 +573,7 @@ def build(self):
def integral_func(*args, **kwargs):
# format arguments
- params_in = Collector()
+ params_in = bm.Collector()
for i, arg in enumerate(args):
params_in[all_vps[i]] = arg
params_in.update(kwargs)
diff --git a/brainpy/integrators/sde/srk_strong.py b/brainpy/integrators/sde/srk_strong.py
index 0cac592ce..74d02d8fa 100644
--- a/brainpy/integrators/sde/srk_strong.py
+++ b/brainpy/integrators/sde/srk_strong.py
@@ -386,7 +386,7 @@ def _srk2_wrapper():
def _wrap(wrapper, f, g, dt, sde_type, var_type, wiener_type, show_code, num_iter):
  """The base function to format a SRK method.
Parameters
----------
diff --git a/brainpy/math/__init__.py b/brainpy/math/__init__.py
index cc73c3587..a08837c82 100644
--- a/brainpy/math/__init__.py
+++ b/brainpy/math/__init__.py
@@ -31,6 +31,9 @@
# data structure
+from .object_transform.base_object import *
+from .object_transform.base_transform import *
+from .object_transform.collector import *
from .ndarray import *
from .delayvars import *
diff --git a/brainpy/math/_utils.py b/brainpy/math/_utils.py
index 9432a2b22..063a9fc8c 100644
--- a/brainpy/math/_utils.py
+++ b/brainpy/math/_utils.py
@@ -1,11 +1,13 @@
# -*- coding: utf-8 -*-
+import functools
from typing import Callable
+import jax
+import numpy as np
+from jax.tree_util import tree_map
-__all__ = [
- 'wraps'
-]
+from .ndarray import Array
def wraps(fun: Callable):
@@ -17,6 +19,7 @@ def wraps(fun: Callable):
this reason, it is important that parameter names match those in the original
numpy function.
"""
+
def wrap(op):
docstr = getattr(fun, "__doc__", None)
op.__doc__ = docstr
@@ -29,4 +32,31 @@ def wrap(op):
else:
setattr(op, attr, value)
return op
+
return wrap
+
+
+def _as_jax_array(a):
+ return a.value if isinstance(a, Array) else a
+
+
+def _as_brainpy_array(a):
+ return Array(a) if isinstance(a, (np.ndarray, jax.Array)) else a
+
+
+def _is_leaf(a):
+ return isinstance(a, Array)
+
+
+def _compatible_with_brainpy_array(fun: Callable, return_brainpy_array: bool = False):
+ @functools.wraps(fun)
+ def new_fun(*args, **kwargs):
+ args = tree_map(_as_jax_array, args, is_leaf=_is_leaf)
+ if len(kwargs):
+ kwargs = tree_map(_as_jax_array, kwargs, is_leaf=_is_leaf)
+ r = fun(*args, **kwargs)
+ return tree_map(_as_brainpy_array, r) if return_brainpy_array else r
+
+ new_fun.__doc__ = getattr(fun, "__doc__", None)
+
+ return new_fun
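This wrapper is how the new `numpy_ops_new.py` below lifts `jax.numpy` functions over brainpy `Array`s: inputs are unwrapped to raw JAX values, and outputs are optionally re-wrapped. An illustrative sketch (these bindings are hypothetical, not the module's actual ones):

    import jax.numpy as jnp

    isreal = _compatible_with_brainpy_array(jnp.isreal)
    add = _compatible_with_brainpy_array(jnp.add, return_brainpy_array=True)

    x = Array(jnp.arange(3.))
    isreal(x)   # Array input unwrapped automatically; returns a raw JAX array
    add(x, 1.)  # returns a brainpy Array because return_brainpy_array=True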
diff --git a/brainpy/math/delayvars.py b/brainpy/math/delayvars.py
index 68292b941..f7bcad028 100644
--- a/brainpy/math/delayvars.py
+++ b/brainpy/math/delayvars.py
@@ -1,18 +1,18 @@
# -*- coding: utf-8 -*-
-from typing import Union, Callable, Tuple
+from typing import Union, Callable
import jax.numpy as jnp
from jax import vmap
from jax.lax import cond, stop_gradient
from brainpy import check
-from brainpy.base.base import BrainPyObject
-from brainpy.errors import UnsupportedError
-from brainpy.math import numpy_ops as bm
-from brainpy.math.ndarray import ndarray, Variable, Array
-from brainpy.math.environment import get_dt, get_float
from brainpy.check import is_float, is_integer, jit_error_checking
+from brainpy.errors import UnsupportedError
+from . import numpy_ops as bm
+from .object_transform.base_object import BrainPyObject
+from .environment import get_dt, get_float
+from .ndarray import ndarray, Variable, Array
__all__ = [
'AbstractDelay',
diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py
index fa7b9c763..56190f0dd 100644
--- a/brainpy/math/environment.py
+++ b/brainpy/math/environment.py
@@ -6,12 +6,12 @@
import os
import re
import sys
-import warnings
from typing import Any, Callable, TypeVar, cast
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge
+from brainpy import errors
from . import modes
bm = None
@@ -63,8 +63,8 @@ def ditype():
.. deprecated:: 2.3.1
Use `brainpy.math.int_` instead.
"""
- warnings.warn('Get default integer data type through `ditype()` has been deprecated. \n'
- 'Use `brainpy.math.int_` instead.')
+ # raise errors.NoLongerSupportError('\nGet default integer data type through `ditype()` has been deprecated. \n'
+ # 'Use `brainpy.math.int_` instead.')
global bm
if bm is None: from brainpy import math as bm
return bm.int_
@@ -77,8 +77,8 @@ def dftype():
Use `brainpy.math.float_` instead.
"""
- warnings.warn('Get default floating data type through `dftype()` has been deprecated. \n'
- 'Use `brainpy.math.float_` instead.')
+ # raise errors.NoLongerSupportError('\nGet default floating data type through `dftype()` has been deprecated. \n'
+ # 'Use `brainpy.math.float_` instead.')
global bm
if bm is None: from brainpy import math as bm
return bm.float_
diff --git a/brainpy/math/numpy_ops_new.py b/brainpy/math/numpy_ops_new.py
new file mode 100644
index 000000000..b575b3b60
--- /dev/null
+++ b/brainpy/math/numpy_ops_new.py
@@ -0,0 +1,750 @@
+# -*- coding: utf-8 -*-
+
+import jax.numpy as jnp
+import numpy as np
+from jax.tree_util import tree_map
+
+from brainpy.math.ndarray import Array, Variable
+from ._utils import _compatible_with_brainpy_array
+
+
+__all__ = [
+ # math funcs
+ 'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar',
+ 'add', 'reciprocal', 'negative', 'positive', 'multiply', 'divide',
+ 'power', 'subtract', 'true_divide', 'floor_divide', 'float_power',
+ 'fmod', 'mod', 'modf', 'divmod', 'remainder', 'abs', 'exp', 'exp2',
+ 'expm1', 'log', 'log10', 'log1p', 'log2', 'logaddexp', 'logaddexp2',
+ 'lcm', 'gcd', 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan',
+ 'arctan2', 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan',
+ 'tanh', 'deg2rad', 'hypot', 'rad2deg', 'degrees', 'radians', 'round',
+ 'around', 'round_', 'rint', 'floor', 'ceil', 'trunc', 'fix', 'prod',
+ 'sum', 'diff', 'median', 'nancumprod', 'nancumsum', 'nanprod', 'nansum',
+ 'cumprod', 'cumsum', 'ediff1d', 'cross', 'trapz', 'isfinite', 'isinf',
+ 'isnan', 'signbit', 'copysign', 'nextafter', 'ldexp', 'frexp', 'convolve',
+ 'sqrt', 'cbrt', 'square', 'absolute', 'fabs', 'sign', 'heaviside',
+ 'maximum', 'minimum', 'fmax', 'fmin', 'interp', 'clip', 'angle',
+
+ # Elementwise bit operations
+ 'bitwise_and', 'bitwise_not', 'bitwise_or', 'bitwise_xor',
+ 'invert', 'left_shift', 'right_shift',
+
+ # logic funcs
+ 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal',
+ 'array_equal', 'isclose', 'allclose', 'logical_and', 'logical_not',
+ 'logical_or', 'logical_xor', 'all', 'any', "alltrue", 'sometrue',
+
+ # array manipulation
+ 'shape', 'size', 'reshape', 'ravel', 'moveaxis', 'transpose', 'swapaxes',
+ 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack',
+ 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique',
+ 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d',
+ 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin',
+ 'argwhere', 'nonzero', 'flatnonzero', 'where', 'searchsorted', 'extract',
+ 'count_nonzero', 'max', 'min', 'amax', 'amin',
+
+ # array creation
+ 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', 'full',
+ 'full_like', 'eye', 'identity', 'array', 'asarray', 'arange', 'linspace',
+ 'logspace', 'meshgrid', 'diag', 'tri', 'tril', 'triu', 'vander', 'fill_diagonal',
+ 'array_split',
+
+ # indexing funcs
+ 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices',
+ 'triu_indices_from', 'take', 'diag', 'select',
+
+ # statistic funcs
+ 'nanmin', 'nanmax', 'ptp', 'percentile', 'nanpercentile', 'quantile', 'nanquantile',
+ 'median', 'average', 'mean', 'std', 'var', 'nanmedian', 'nanmean', 'nanstd', 'nanvar',
+ 'corrcoef', 'correlate', 'cov', 'histogram', 'bincount', 'digitize',
+
+ # window funcs
+ 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser',
+
+ # constants
+ 'e', 'pi', 'inf',
+
+ # linear algebra
+ 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace',
+
+ # data types
+ 'dtype', 'finfo', 'iinfo', 'uint8', 'uint16', 'uint32', 'uint64',
+ 'int8', 'int16', 'int32', 'int64', 'float16', 'float32',
+ 'float64', 'complex64', 'complex128',
+
+ # more
+ 'product', 'row_stack', 'apply_over_axes', 'apply_along_axis', 'array_equiv',
+ 'array_repr', 'array_str', 'block', 'broadcast_arrays', 'broadcast_shapes',
+ 'broadcast_to', 'compress', 'cumproduct', 'diag_indices', 'diag_indices_from',
+ 'diagflat', 'diagonal', 'einsum', 'einsum_path', 'geomspace', 'gradient',
+ 'histogram2d', 'histogram_bin_edges', 'histogramdd', 'i0', 'in1d', 'indices',
+ 'insert', 'intersect1d', 'iscomplex', 'isin', 'ix_', 'lexsort', 'load',
+ 'save', 'savez', 'mask_indices', 'msort', 'nan_to_num', 'nanargmax', 'setdiff1d',
+ 'nanargmin', 'pad', 'poly', 'polyadd', 'polyder', 'polyfit', 'polyint',
+ 'polymul', 'polysub', 'polyval', 'resize', 'rollaxis', 'roots', 'rot90',
+ 'setxor1d', 'tensordot', 'trim_zeros', 'union1d', 'unravel_index', 'unwrap',
+ 'take_along_axis', 'can_cast', 'choose', 'copy', 'frombuffer', 'fromfile',
+ 'fromfunction', 'fromiter', 'fromstring', 'get_printoptions', 'iscomplexobj',
+ 'isneginf', 'isposinf', 'isrealobj', 'issubdtype', 'issubsctype', 'iterable',
+ 'packbits', 'piecewise', 'printoptions', 'set_printoptions', 'promote_types',
+ 'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete',
+
+ # unique
+ 'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray',
+ 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt',
+ 'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
+ 'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat',
+
+ # others
+ 'clip_by_norm', 'remove_diag',
+ 'as_device_array', 'as_jax', 'as_ndarray', 'as_numpy',
+ 'as_variable',
+]
+
+_min = min
+_max = max
+
+
+# others
+# ------
+
+
+def remove_diag(arr):
+ """Remove the diagonal of the matrix.
+
+ Parameters
+ ----------
+ arr: ArrayType
+ The matrix with the shape of `(M, N)`.
+
+ Returns
+ -------
+ arr: Array
+ The matrix without its diagonal, with shape `(M, N-1)`.
+ """
+ if arr.ndim != 2:
+ raise ValueError(f'Only support 2D matrix, while we got a {arr.ndim}D array.')
+ eyes = Array(ones(arr.shape, dtype=bool))
+ fill_diagonal(eyes, False)
+ return reshape(arr[eyes.value], (arr.shape[0], arr.shape[1] - 1))
+
+
+def as_jax(tensor, dtype=None):
+ """Convert the input to a ``jax.numpy.DeviceArray``.
+
+ Parameters
+ ----------
+ tensor: array_like
+ Input data, in any form that can be converted to an array. This
+ includes lists, lists of tuples, tuples, tuples of tuples, tuples
+ of lists, ArrayType.
+ dtype: data-type, optional
+ By default, the data-type is inferred from the input data.
+
+ Returns
+ -------
+ out : ArrayType
+ Array interpretation of `tensor`. No copy is performed if the input
+ is already an ndarray with matching dtype.
+ """
+ if isinstance(tensor, Array):
+ return tensor.to_jax(dtype)
+ elif isinstance(tensor, jnp.ndarray):
+ return tensor if (dtype is None) else jnp.asarray(tensor, dtype=dtype)
+ else:
+ # np.ndarray and anything else array-like goes through jnp.asarray directly
+ return jnp.asarray(tensor, dtype=dtype)
+
+
+as_device_array = as_jax
+
+
+def as_numpy(tensor, dtype=None):
+ """Convert the input to a ``numpy.ndarray``.
+
+ Parameters
+ ----------
+ tensor: array_like
+ Input data, in any form that can be converted to an array. This
+ includes lists, lists of tuples, tuples, tuples of tuples, tuples
+ of lists, ArrayType.
+ dtype: data-type, optional
+ By default, the data-type is inferred from the input data.
+
+ Returns
+ -------
+ out : ndarray
+ Array interpretation of `tensor`. No copy is performed if the input
+ is already an ndarray with matching dtype.
+ """
+ if isinstance(tensor, Array):
+ return tensor.to_numpy(dtype=dtype)
+ else:
+ return np.asarray(tensor, dtype=dtype)
+
+
+as_ndarray = as_numpy
+
+
+def as_variable(tensor, dtype=None):
+ """Convert the input to a ``brainpy.math.Variable``.
+
+ Parameters
+ ----------
+ tensor: array_like
+ Input data, in any form that can be converted to an array. This
+ includes lists, lists of tuples, tuples, tuples of tuples, tuples
+ of lists, ArrayType.
+ dtype: data-type, optional
+ By default, the data-type is inferred from the input data.
+
+ Returns
+ -------
+ out : Variable
+ Variable interpretation of `tensor`.
+ """
+ return Variable(asarray(tensor, dtype=dtype))
+
+
+def _as_jax_array_(obj):
+ return obj.value if isinstance(obj, Array) else obj
+
+
+def clip_by_norm(t, clip_norm, axis=None):
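+ """Scale `t` (an array or pytree of arrays) so that the L2 norm along
+ `axis` does not exceed `clip_norm`. Leaves whose norm is already within
+ `clip_norm` are unchanged, since the scaling factor is then exactly 1.
+ """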
+ f = lambda l: l * clip_norm / maximum(sqrt(sum(l * l, axis=axis, keepdims=True)), clip_norm)
+ return tree_map(f, t)
+
+
+# array creation
+# --------------
+
+zeros = _compatible_with_brainpy_array(jnp.zeros, True)
+ones = _compatible_with_brainpy_array(jnp.ones, True)
+full = _compatible_with_brainpy_array(jnp.full, True)
+empty = _compatible_with_brainpy_array(jnp.empty, True)
+zeros_like = _compatible_with_brainpy_array(jnp.zeros_like, True)
+ones_like = _compatible_with_brainpy_array(jnp.ones_like, True)
+empty_like = _compatible_with_brainpy_array(jnp.empty_like, True)
+full_like = _compatible_with_brainpy_array(jnp.full_like, True)
+eye = _compatible_with_brainpy_array(jnp.eye, True)
+identity = _compatible_with_brainpy_array(jnp.identity, True)
+array = _compatible_with_brainpy_array(jnp.array, True)
+asarray = _compatible_with_brainpy_array(jnp.asarray, True)
+arange = _compatible_with_brainpy_array(jnp.arange, True)
+linspace = _compatible_with_brainpy_array(jnp.linspace, True)
+logspace = _compatible_with_brainpy_array(jnp.logspace, True)
+meshgrid = _compatible_with_brainpy_array(jnp.meshgrid, True)
+diag = _compatible_with_brainpy_array(jnp.diag, True)
+tri = _compatible_with_brainpy_array(jnp.tri, True)
+tril = _compatible_with_brainpy_array(jnp.tril, True)
+triu = _compatible_with_brainpy_array(jnp.triu, True)
+vander = _compatible_with_brainpy_array(jnp.vander, True)
+
+
+def asanyarray(a, dtype=None, order=None):
+ return asarray(a, dtype=dtype, order=order)
+
+
+def ascontiguousarray(a, dtype=None, order=None):
+ return asarray(a, dtype=dtype, order=order)
+
+
+def asfarray(a, dtype=np.float_):
+ if not np.issubdtype(dtype, np.inexact):
+ dtype = np.float_
+ return asarray(a, dtype=dtype)
+
+
+def fill_diagonal(a, val):
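+ """Fill the main diagonal of the last two axes of the brainpy Array `a`
+ with `val`, in place (a batched analogue of ``numpy.fill_diagonal``).
+ """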
+ if not isinstance(a, Array):
+ raise ValueError(f'Must be a brainpy Array, but got {type(a)}')
+ if a.ndim < 2:
+ raise ValueError(f'Only support Array has dimension >= 2, but got {a.shape}')
+ val = _as_jax_array_(val)
+ i, j = jnp.diag_indices(_min(a.shape[-2:]))
+ a._value = a.value.at[..., i, j].set(val)
+
+
+# Others
+# ------
+
+delete = _compatible_with_brainpy_array(jnp.delete)
+take_along_axis = _compatible_with_brainpy_array(jnp.take_along_axis)
+block = _compatible_with_brainpy_array(jnp.block)
+broadcast_arrays = _compatible_with_brainpy_array(jnp.broadcast_arrays)
+broadcast_shapes = _compatible_with_brainpy_array(jnp.broadcast_shapes)
+broadcast_to = _compatible_with_brainpy_array(jnp.broadcast_to)
+compress = _compatible_with_brainpy_array(jnp.compress)
+diag_indices = _compatible_with_brainpy_array(jnp.diag_indices)
+diag_indices_from = _compatible_with_brainpy_array(jnp.diag_indices_from)
+diagflat = _compatible_with_brainpy_array(jnp.diagflat)
+diagonal = _compatible_with_brainpy_array(jnp.diagonal)
+einsum = _compatible_with_brainpy_array(jnp.einsum)
+einsum_path = _compatible_with_brainpy_array(jnp.einsum_path)
+geomspace = _compatible_with_brainpy_array(jnp.geomspace)
+gradient = _compatible_with_brainpy_array(jnp.gradient)
+histogram2d = _compatible_with_brainpy_array(jnp.histogram2d)
+histogram_bin_edges = _compatible_with_brainpy_array(jnp.histogram_bin_edges)
+histogramdd = _compatible_with_brainpy_array(jnp.histogramdd)
+i0 = _compatible_with_brainpy_array(jnp.i0)
+in1d = _compatible_with_brainpy_array(jnp.in1d)
+indices = _compatible_with_brainpy_array(jnp.indices)
+insert = _compatible_with_brainpy_array(jnp.insert)
+intersect1d = _compatible_with_brainpy_array(jnp.intersect1d)
+iscomplex = _compatible_with_brainpy_array(jnp.iscomplex)
+isin = _compatible_with_brainpy_array(jnp.isin)
+ix_ = _compatible_with_brainpy_array(jnp.ix_)
+lexsort = _compatible_with_brainpy_array(jnp.lexsort)
+load = _compatible_with_brainpy_array(jnp.load)
+save = _compatible_with_brainpy_array(jnp.save)
+savez = _compatible_with_brainpy_array(jnp.savez)
+mask_indices = _compatible_with_brainpy_array(jnp.mask_indices)
+msort = _compatible_with_brainpy_array(jnp.msort)
+nan_to_num = _compatible_with_brainpy_array(jnp.nan_to_num)
+nanargmax = _compatible_with_brainpy_array(jnp.nanargmax)
+nanargmin = _compatible_with_brainpy_array(jnp.nanargmin)
+pad = _compatible_with_brainpy_array(jnp.pad)
+poly = _compatible_with_brainpy_array(jnp.poly)
+polyadd = _compatible_with_brainpy_array(jnp.polyadd)
+polyder = _compatible_with_brainpy_array(jnp.polyder)
+polyfit = _compatible_with_brainpy_array(jnp.polyfit)
+polyint = _compatible_with_brainpy_array(jnp.polyint)
+polymul = _compatible_with_brainpy_array(jnp.polymul)
+polysub = _compatible_with_brainpy_array(jnp.polysub)
+polyval = _compatible_with_brainpy_array(jnp.polyval)
+resize = _compatible_with_brainpy_array(jnp.resize)
+rollaxis = _compatible_with_brainpy_array(jnp.rollaxis)
+roots = _compatible_with_brainpy_array(jnp.roots)
+rot90 = _compatible_with_brainpy_array(jnp.rot90)
+setdiff1d = _compatible_with_brainpy_array(jnp.setdiff1d)
+setxor1d = _compatible_with_brainpy_array(jnp.setxor1d)
+tensordot = _compatible_with_brainpy_array(jnp.tensordot)
+trim_zeros = _compatible_with_brainpy_array(jnp.trim_zeros)
+union1d = _compatible_with_brainpy_array(jnp.union1d)
+unravel_index = _compatible_with_brainpy_array(jnp.unravel_index)
+unwrap = _compatible_with_brainpy_array(jnp.unwrap)
+
+# math funcs
+# ----------
+isreal = _compatible_with_brainpy_array(jnp.isreal)
+isscalar = _compatible_with_brainpy_array(jnp.isscalar)
+real = _compatible_with_brainpy_array(jnp.real)
+imag = _compatible_with_brainpy_array(jnp.imag)
+conj = _compatible_with_brainpy_array(jnp.conj)
+conjugate = _compatible_with_brainpy_array(jnp.conjugate)
+ndim = _compatible_with_brainpy_array(jnp.ndim)
+add = _compatible_with_brainpy_array(jnp.add)
+reciprocal = _compatible_with_brainpy_array(jnp.reciprocal)
+negative = _compatible_with_brainpy_array(jnp.negative)
+positive = _compatible_with_brainpy_array(jnp.positive)
+multiply = _compatible_with_brainpy_array(jnp.multiply)
+divide = _compatible_with_brainpy_array(jnp.divide)
+power = _compatible_with_brainpy_array(jnp.power)
+subtract = _compatible_with_brainpy_array(jnp.subtract)
+true_divide = _compatible_with_brainpy_array(jnp.true_divide)
+floor_divide = _compatible_with_brainpy_array(jnp.floor_divide)
+float_power = _compatible_with_brainpy_array(jnp.float_power)
+fmod = _compatible_with_brainpy_array(jnp.fmod)
+mod = _compatible_with_brainpy_array(jnp.mod)
+divmod = _compatible_with_brainpy_array(jnp.divmod)
+remainder = _compatible_with_brainpy_array(jnp.remainder)
+modf = _compatible_with_brainpy_array(jnp.modf)
+abs = _compatible_with_brainpy_array(jnp.abs)
+absolute = _compatible_with_brainpy_array(jnp.absolute)
+exp = _compatible_with_brainpy_array(jnp.exp)
+exp2 = _compatible_with_brainpy_array(jnp.exp2)
+expm1 = _compatible_with_brainpy_array(jnp.expm1)
+log = _compatible_with_brainpy_array(jnp.log)
+log10 = _compatible_with_brainpy_array(jnp.log10)
+log1p = _compatible_with_brainpy_array(jnp.log1p)
+log2 = _compatible_with_brainpy_array(jnp.log2)
+logaddexp = _compatible_with_brainpy_array(jnp.logaddexp)
+logaddexp2 = _compatible_with_brainpy_array(jnp.logaddexp2)
+lcm = _compatible_with_brainpy_array(jnp.lcm)
+gcd = _compatible_with_brainpy_array(jnp.gcd)
+arccos = _compatible_with_brainpy_array(jnp.arccos)
+arccosh = _compatible_with_brainpy_array(jnp.arccosh)
+arcsin = _compatible_with_brainpy_array(jnp.arcsin)
+arcsinh = _compatible_with_brainpy_array(jnp.arcsinh)
+arctan = _compatible_with_brainpy_array(jnp.arctan)
+arctan2 = _compatible_with_brainpy_array(jnp.arctan2)
+arctanh = _compatible_with_brainpy_array(jnp.arctanh)
+cos = _compatible_with_brainpy_array(jnp.cos)
+cosh = _compatible_with_brainpy_array(jnp.cosh)
+sin = _compatible_with_brainpy_array(jnp.sin)
+sinc = _compatible_with_brainpy_array(jnp.sinc)
+sinh = _compatible_with_brainpy_array(jnp.sinh)
+tan = _compatible_with_brainpy_array(jnp.tan)
+tanh = _compatible_with_brainpy_array(jnp.tanh)
+deg2rad = _compatible_with_brainpy_array(jnp.deg2rad)
+rad2deg = _compatible_with_brainpy_array(jnp.rad2deg)
+degrees = _compatible_with_brainpy_array(jnp.degrees)
+radians = _compatible_with_brainpy_array(jnp.radians)
+hypot = _compatible_with_brainpy_array(jnp.hypot)
+round = _compatible_with_brainpy_array(jnp.round)
+around = round
+round_ = round
+rint = _compatible_with_brainpy_array(jnp.rint)
+floor = _compatible_with_brainpy_array(jnp.floor)
+ceil = _compatible_with_brainpy_array(jnp.ceil)
+trunc = _compatible_with_brainpy_array(jnp.trunc)
+fix = _compatible_with_brainpy_array(jnp.fix)
+prod = _compatible_with_brainpy_array(jnp.prod)
+sum = _compatible_with_brainpy_array(jnp.sum)
+diff = _compatible_with_brainpy_array(jnp.diff)
+median = _compatible_with_brainpy_array(jnp.median)
+nancumprod = _compatible_with_brainpy_array(jnp.nancumprod)
+nancumsum = _compatible_with_brainpy_array(jnp.nancumsum)
+cumprod = _compatible_with_brainpy_array(jnp.cumprod)
+cumproduct = cumprod
+cumsum = _compatible_with_brainpy_array(jnp.cumsum)
+nanprod = _compatible_with_brainpy_array(jnp.nanprod)
+nansum = _compatible_with_brainpy_array(jnp.nansum)
+ediff1d = _compatible_with_brainpy_array(jnp.ediff1d)
+cross = _compatible_with_brainpy_array(jnp.cross)
+trapz = _compatible_with_brainpy_array(jnp.trapz)
+isfinite = _compatible_with_brainpy_array(jnp.isfinite)
+isinf = _compatible_with_brainpy_array(jnp.isinf)
+isnan = _compatible_with_brainpy_array(jnp.isnan)
+signbit = _compatible_with_brainpy_array(jnp.signbit)
+nextafter = _compatible_with_brainpy_array(jnp.nextafter)
+copysign = _compatible_with_brainpy_array(jnp.copysign)
+ldexp = _compatible_with_brainpy_array(jnp.ldexp)
+frexp = _compatible_with_brainpy_array(jnp.frexp)
+convolve = _compatible_with_brainpy_array(jnp.convolve)
+sqrt = _compatible_with_brainpy_array(jnp.sqrt)
+cbrt = _compatible_with_brainpy_array(jnp.cbrt)
+square = _compatible_with_brainpy_array(jnp.square)
+fabs = _compatible_with_brainpy_array(jnp.fabs)
+sign = _compatible_with_brainpy_array(jnp.sign)
+heaviside = _compatible_with_brainpy_array(jnp.heaviside)
+maximum = _compatible_with_brainpy_array(jnp.maximum)
+minimum = _compatible_with_brainpy_array(jnp.minimum)
+fmax = _compatible_with_brainpy_array(jnp.fmax)
+fmin = _compatible_with_brainpy_array(jnp.fmin)
+interp = _compatible_with_brainpy_array(jnp.interp)
+clip = _compatible_with_brainpy_array(jnp.clip)
+angle = _compatible_with_brainpy_array(jnp.angle)
+bitwise_not = _compatible_with_brainpy_array(jnp.bitwise_not)
+invert = _compatible_with_brainpy_array(jnp.invert)
+bitwise_and = _compatible_with_brainpy_array(jnp.bitwise_and)
+bitwise_or = _compatible_with_brainpy_array(jnp.bitwise_or)
+bitwise_xor = _compatible_with_brainpy_array(jnp.bitwise_xor)
+left_shift = _compatible_with_brainpy_array(jnp.left_shift)
+right_shift = _compatible_with_brainpy_array(jnp.right_shift)
+equal = _compatible_with_brainpy_array(jnp.equal)
+not_equal = _compatible_with_brainpy_array(jnp.not_equal)
+greater = _compatible_with_brainpy_array(jnp.greater)
+greater_equal = _compatible_with_brainpy_array(jnp.greater_equal)
+less = _compatible_with_brainpy_array(jnp.less)
+less_equal = _compatible_with_brainpy_array(jnp.less_equal)
+array_equal = _compatible_with_brainpy_array(jnp.array_equal)
+isclose = _compatible_with_brainpy_array(jnp.isclose)
+allclose = _compatible_with_brainpy_array(jnp.allclose)
+logical_not = _compatible_with_brainpy_array(jnp.logical_not)
+logical_and = _compatible_with_brainpy_array(jnp.logical_and)
+logical_or = _compatible_with_brainpy_array(jnp.logical_or)
+logical_xor = _compatible_with_brainpy_array(jnp.logical_xor)
+all = _compatible_with_brainpy_array(jnp.all)
+any = _compatible_with_brainpy_array(jnp.any)
+alltrue = all
+sometrue = any
+
+# array manipulation
+# ------------------
+
+shape = _compatible_with_brainpy_array(jnp.shape)
+size = _compatible_with_brainpy_array(jnp.size)
+reshape = _compatible_with_brainpy_array(jnp.reshape)
+ravel = _compatible_with_brainpy_array(jnp.ravel)
+moveaxis = _compatible_with_brainpy_array(jnp.moveaxis)
+transpose = _compatible_with_brainpy_array(jnp.transpose)
+swapaxes = _compatible_with_brainpy_array(jnp.swapaxes)
+concatenate = _compatible_with_brainpy_array(jnp.concatenate)
+stack = _compatible_with_brainpy_array(jnp.stack)
+vstack = _compatible_with_brainpy_array(jnp.vstack)
+product = prod
+row_stack = vstack
+hstack = _compatible_with_brainpy_array(jnp.hstack)
+dstack = _compatible_with_brainpy_array(jnp.dstack)
+column_stack = _compatible_with_brainpy_array(jnp.column_stack)
+split = _compatible_with_brainpy_array(jnp.split)
+dsplit = _compatible_with_brainpy_array(jnp.dsplit)
+hsplit = _compatible_with_brainpy_array(jnp.hsplit)
+vsplit = _compatible_with_brainpy_array(jnp.vsplit)
+tile = _compatible_with_brainpy_array(jnp.tile)
+repeat = _compatible_with_brainpy_array(jnp.repeat)
+unique = _compatible_with_brainpy_array(jnp.unique)
+append = _compatible_with_brainpy_array(jnp.append)
+flip = _compatible_with_brainpy_array(jnp.flip)
+fliplr = _compatible_with_brainpy_array(jnp.fliplr)
+flipud = _compatible_with_brainpy_array(jnp.flipud)
+roll = _compatible_with_brainpy_array(jnp.roll)
+atleast_1d = _compatible_with_brainpy_array(jnp.atleast_1d)
+atleast_2d = _compatible_with_brainpy_array(jnp.atleast_2d)
+atleast_3d = _compatible_with_brainpy_array(jnp.atleast_3d)
+expand_dims = _compatible_with_brainpy_array(jnp.expand_dims)
+squeeze = _compatible_with_brainpy_array(jnp.squeeze)
+sort = _compatible_with_brainpy_array(jnp.sort)
+argsort = _compatible_with_brainpy_array(jnp.argsort)
+argmax = _compatible_with_brainpy_array(jnp.argmax)
+argmin = _compatible_with_brainpy_array(jnp.argmin)
+argwhere = _compatible_with_brainpy_array(jnp.argwhere)
+nonzero = _compatible_with_brainpy_array(jnp.nonzero)
+flatnonzero = _compatible_with_brainpy_array(jnp.flatnonzero)
+where = _compatible_with_brainpy_array(jnp.where)
+searchsorted = _compatible_with_brainpy_array(jnp.searchsorted)
+extract = _compatible_with_brainpy_array(jnp.extract)
+count_nonzero = _compatible_with_brainpy_array(jnp.count_nonzero)
+max = _compatible_with_brainpy_array(jnp.max)
+min = _compatible_with_brainpy_array(jnp.min)
+amax = max
+amin = min
+apply_along_axis = _compatible_with_brainpy_array(jnp.apply_along_axis)
+apply_over_axes = _compatible_with_brainpy_array(jnp.apply_over_axes)
+array_equiv = _compatible_with_brainpy_array(jnp.array_equiv)
+array_repr = _compatible_with_brainpy_array(jnp.array_repr)
+array_str = _compatible_with_brainpy_array(jnp.array_str)
+array_split = _compatible_with_brainpy_array(jnp.array_split)
+
+# indexing funcs
+# --------------
+
+tril_indices = jnp.tril_indices
+triu_indices = jnp.triu_indices
+tril_indices_from = _compatible_with_brainpy_array(jnp.tril_indices_from)
+triu_indices_from = _compatible_with_brainpy_array(jnp.triu_indices_from)
+take = _compatible_with_brainpy_array(jnp.take)
+select = _compatible_with_brainpy_array(jnp.select)
+nanmin = _compatible_with_brainpy_array(jnp.nanmin)
+nanmax = _compatible_with_brainpy_array(jnp.nanmax)
+ptp = _compatible_with_brainpy_array(jnp.ptp)
+percentile = _compatible_with_brainpy_array(jnp.percentile)
+nanpercentile = _compatible_with_brainpy_array(jnp.nanpercentile)
+quantile = _compatible_with_brainpy_array(jnp.quantile)
+nanquantile = _compatible_with_brainpy_array(jnp.nanquantile)
+average = _compatible_with_brainpy_array(jnp.average)
+mean = _compatible_with_brainpy_array(jnp.mean)
+std = _compatible_with_brainpy_array(jnp.std)
+var = _compatible_with_brainpy_array(jnp.var)
+nanmedian = _compatible_with_brainpy_array(jnp.nanmedian)
+nanmean = _compatible_with_brainpy_array(jnp.nanmean)
+nanstd = _compatible_with_brainpy_array(jnp.nanstd)
+nanvar = _compatible_with_brainpy_array(jnp.nanvar)
+corrcoef = _compatible_with_brainpy_array(jnp.corrcoef)
+correlate = _compatible_with_brainpy_array(jnp.correlate)
+cov = _compatible_with_brainpy_array(jnp.cov)
+histogram = _compatible_with_brainpy_array(jnp.histogram)
+bincount = _compatible_with_brainpy_array(jnp.bincount)
+digitize = _compatible_with_brainpy_array(jnp.digitize)
+bartlett = _compatible_with_brainpy_array(jnp.bartlett)
+blackman = _compatible_with_brainpy_array(jnp.blackman)
+hamming = _compatible_with_brainpy_array(jnp.hamming)
+hanning = _compatible_with_brainpy_array(jnp.hanning)
+kaiser = _compatible_with_brainpy_array(jnp.kaiser)
+
+# constants
+# ---------
+
+e = jnp.e
+pi = jnp.pi
+inf = jnp.inf
+
+# linear algebra
+# --------------
+
+dot = _compatible_with_brainpy_array(jnp.dot)
+vdot = _compatible_with_brainpy_array(jnp.vdot)
+inner = _compatible_with_brainpy_array(jnp.inner)
+outer = _compatible_with_brainpy_array(jnp.outer)
+kron = _compatible_with_brainpy_array(jnp.kron)
+matmul = _compatible_with_brainpy_array(jnp.matmul)
+trace = _compatible_with_brainpy_array(jnp.trace)
+
+# data types
+# ----------
+
+dtype = jnp.dtype
+finfo = jnp.finfo
+iinfo = jnp.iinfo
+
+uint8 = jnp.uint8
+uint16 = jnp.uint16
+uint32 = jnp.uint32
+uint64 = jnp.uint64
+int8 = jnp.int8
+int16 = jnp.int16
+int32 = jnp.int32
+int64 = jnp.int64
+float16 = jnp.float16
+float32 = jnp.float32
+float64 = jnp.float64
+complex64 = jnp.complex64
+complex128 = jnp.complex128
+
+can_cast = _compatible_with_brainpy_array(jnp.can_cast)
+choose = _compatible_with_brainpy_array(jnp.choose)
+copy = _compatible_with_brainpy_array(jnp.copy)
+frombuffer = _compatible_with_brainpy_array(jnp.frombuffer)
+fromfile = _compatible_with_brainpy_array(jnp.fromfile)
+fromfunction = _compatible_with_brainpy_array(jnp.fromfunction)
+fromiter = _compatible_with_brainpy_array(jnp.fromiter)
+fromstring = _compatible_with_brainpy_array(jnp.fromstring)
+get_printoptions = np.get_printoptions
+iscomplexobj = _compatible_with_brainpy_array(jnp.iscomplexobj)
+isneginf = _compatible_with_brainpy_array(jnp.isneginf)
+isposinf = _compatible_with_brainpy_array(jnp.isposinf)
+isrealobj = _compatible_with_brainpy_array(jnp.isrealobj)
+issubdtype = jnp.issubdtype
+issubsctype = jnp.issubsctype
+iterable = _compatible_with_brainpy_array(jnp.iterable)
+packbits = _compatible_with_brainpy_array(jnp.packbits)
+piecewise = _compatible_with_brainpy_array(jnp.piecewise)
+printoptions = np.printoptions
+set_printoptions = np.set_printoptions
+promote_types = _compatible_with_brainpy_array(jnp.promote_types)
+ravel_multi_index = _compatible_with_brainpy_array(jnp.ravel_multi_index)
+result_type = _compatible_with_brainpy_array(jnp.result_type)
+sort_complex = _compatible_with_brainpy_array(jnp.sort_complex)
+unpackbits = _compatible_with_brainpy_array(jnp.unpackbits)
+
+# Unique APIs
+# -----------
+
+add_docstring = np.add_docstring
+add_newdoc = np.add_newdoc
+add_newdoc_ufunc = np.add_newdoc_ufunc
+
+
+def array2string(a, max_line_width=None, precision=None,
+ suppress_small=None, separator=' ', prefix="",
+ style=np._NoValue, formatter=None, threshold=None,
+ edgeitems=None, sign=None, floatmode=None, suffix="",
+ legacy=None):
+ a = as_numpy(a)
+ return np.array2string(a, max_line_width=max_line_width, precision=precision,
+ suppress_small=suppress_small, separator=separator, prefix=prefix,
+ style=style, formatter=formatter, threshold=threshold,
+ edgeitems=edgeitems, sign=sign, floatmode=floatmode, suffix=suffix,
+ legacy=legacy)
+
+
+def asscalar(a):
+ return a.item()
+
+
+array_type = [[np.half, np.single, np.double, np.longdouble],
+ [None, np.csingle, np.cdouble, np.clongdouble]]
+array_precision = {np.half: 0,
+ np.single: 1,
+ np.double: 2,
+ np.longdouble: 3,
+ np.csingle: 1,
+ np.cdouble: 2,
+ np.clongdouble: 3}
+
+
+def common_type(*arrays):
+ is_complex = False
+ precision = 0
+ for a in arrays:
+ t = a.dtype.type
+ if iscomplexobj(a):
+ is_complex = True
+ if issubclass(t, jnp.integer):
+ p = 2 # array_precision[_nx.double]
+ else:
+ p = array_precision.get(t, None)
+ if p is None:
+ raise TypeError("can't get common type for non-numeric array")
+ precision = _max(precision, p) # the builtin max saved above; this module rebinds `max` to jnp.max
+ if is_complex:
+ return array_type[1][precision]
+ else:
+ return array_type[0][precision]
+
+
+disp = np.disp
+
+genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
+loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))
+
+info = np.info
+issubclass_ = np.issubclass_
+
+
+def place(arr, mask, vals):
+ if not isinstance(arr, Array):
+ raise ValueError(f'Must be an instance of brainpy Array, but we got {type(arr)}')
+ arr[mask] = vals
+
+
+polydiv = _compatible_with_brainpy_array(jnp.polydiv)
+
+
+def put(a, ind, v):
+ if not isinstance(a, Array):
+ raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}')
+ a[ind] = v
+
+
+def putmask(a, mask, values):
+ if not isinstance(a, Array):
+ raise ValueError(f'Must be an instance of brainpy Array, but we got {type(a)}')
+ if a.shape != values.shape:
+ raise ValueError(f'The shape of "values" must match the shape of "a", but got {values.shape} != {a.shape}.')
+ a[mask] = values
+
+
+def safe_eval(source):
+ return tree_map(Array, np.safe_eval(source))
+
+
+def savetxt(fname, X, fmt='%.18e', delimiter=' ', newline='\n', header='',
+ footer='', comments='# ', encoding=None):
+ X = as_numpy(X)
+ np.savetxt(fname, X, fmt=fmt, delimiter=delimiter, newline=newline, header=header,
+ footer=footer, comments=comments, encoding=encoding)
+
+
+def savez_compressed(file, *args, **kwds):
+ args = tuple([(as_numpy(a) if isinstance(a, (jnp.ndarray, Array)) else a) for a in args])
+ kwds = {k: (as_numpy(v) if isinstance(v, (jnp.ndarray, Array)) else v)
+ for k, v in kwds.items()}
+ np.savez_compressed(file, *args, **kwds)
+
+
+show_config = np.show_config
+typename = np.typename
+
+
+def copyto(dst, src):
+ if not isinstance(dst, Array):
+ raise ValueError(f'"dst" must be an instance of brainpy Array, but we got {type(dst)}.')
+ dst[:] = src
+
+
+def matrix(data, dtype=None):
+ data = array(data, copy=True, dtype=dtype)
+ if data.ndim > 2:
+ raise ValueError(f'shape too large {data.shape} to be a matrix.')
+ if data.ndim != 2:
+ for i in range(2 - data.ndim):
+ data = expand_dims(data, 0)
+ return data
+
+
+def asmatrix(data, dtype=None):
+ data = array(data, dtype=dtype)
+ if data.ndim > 2:
+ raise ValueError(f'shape too large {data.shape} to be a matrix.')
+ if data.ndim != 2:
+ for i in range(2 - data.ndim):
+ data = expand_dims(data, 0)
+ return data
+
+
+def mat(data, dtype=None):
+ return asmatrix(data, dtype=dtype)
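
Note: nearly every assignment in `numpy_ops_new.py` goes through
`_compatible_with_brainpy_array`, which is imported from `._utils` and is not
part of this diff. A minimal sketch of the pattern, assuming the wrapper only
needs to unwrap brainpy `Array` arguments and (optionally) rewrap results:

    import functools
    import jax.numpy as jnp
    from jax.tree_util import tree_map
    from brainpy.math.ndarray import Array

    def _compatible_with_brainpy_array(fun, return_brainpy_array=False):
        # Hypothetical sketch, not the actual implementation in `_utils.py`.
        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            is_arr = lambda a: isinstance(a, Array)
            unwrap = lambda a: a.value if is_arr(a) else a
            # unwrap Array leaves to their underlying jax values
            args = tree_map(unwrap, args, is_leaf=is_arr)
            kwargs = tree_map(unwrap, kwargs, is_leaf=is_arr)
            out = fun(*args, **kwargs)
            # optionally wrap jax outputs back into brainpy Arrays
            if return_brainpy_array:
                out = tree_map(lambda o: Array(o) if isinstance(o, jnp.ndarray) else o, out)
            return out
        return wrapper
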
diff --git a/brainpy/math/object_transform/__init__.py b/brainpy/math/object_transform/__init__.py
index d08f904e4..53b64c9ef 100644
--- a/brainpy/math/object_transform/__init__.py
+++ b/brainpy/math/object_transform/__init__.py
@@ -1,6 +1,24 @@
# -*- coding: utf-8 -*-
+"""
+The ``brainpy_object`` module for whole BrainPy ecosystem.
+
+- This module provides the most fundamental class ``BrainPyObject``,
+ and its associated helper class ``Collector`` and ``ArrayCollector``.
+- For each instance of "BrainPyObject" class, users can retrieve all
+ the variables (or trainable variables), integrators, and nodes.
+- This module also provides a ``FunAsObject`` class to wrap user-defined
+ functions. In each function, maybe several nodes are used, and
+ users can initialize a ``FunAsObject`` by providing the nodes used
+ in the function. Unfortunately, ``FunAsObject`` class does not have
+ the ability to gather nodes automatically.
+
+Details please see the following.
+"""
from . import (
+ base_object,
+ base_transform,
+ collector,
autograd,
controls,
jit,
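
A minimal usage sketch of the ``FunAsObject`` described in the docstring above
(assuming the class is re-exported as ``brainpy.math.FunAsObject``; note that,
as the docstring says, the used variables must be declared by hand):

    import brainpy.math as bm

    v = bm.Variable(bm.zeros(1))

    def accumulate(x):
        v.value += x

    obj = bm.FunAsObject(accumulate, dyn_vars=v)  # declare `v` explicitly
    obj(1.)          # calls the wrapped function
    len(obj.vars())  # -> 1; `v` is now visible to BrainPy transforms
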
diff --git a/brainpy/math/object_transform/_utils.py b/brainpy/math/object_transform/_utils.py
index a0d678be0..7eb9d8633 100644
--- a/brainpy/math/object_transform/_utils.py
+++ b/brainpy/math/object_transform/_utils.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
-from typing import Tuple, Dict
-from brainpy.base import BrainPyObject, ArrayCollector
+from typing import Dict
+from .base_object import BrainPyObject, ArrayCollector
__all__ = [
'infer_dyn_vars',
diff --git a/brainpy/math/object_transform/autograd.py b/brainpy/math/object_transform/autograd.py
index ae7aafaa7..2ee4d3a1d 100644
--- a/brainpy/math/object_transform/autograd.py
+++ b/brainpy/math/object_transform/autograd.py
@@ -18,9 +18,9 @@
from jax.util import safe_map
from brainpy import errors, tools, check
-from brainpy.base import BrainPyObject
-from brainpy.math.ndarray import Array, Variable, add_context, del_context
-from .base import ObjectTransform
+from .base_object import BrainPyObject
+from .base_transform import ObjectTransform
+from ..ndarray import Array, Variable, add_context, del_context
__all__ = [
'grad', # gradient of scalar function
diff --git a/brainpy/math/object_transform/autograd_old.py b/brainpy/math/object_transform/autograd_old.py
index b726201f7..2169ed55e 100644
--- a/brainpy/math/object_transform/autograd_old.py
+++ b/brainpy/math/object_transform/autograd_old.py
@@ -16,10 +16,10 @@
from jax.util import safe_map
from brainpy import errors, tools
-from brainpy.base import get_unique_name, ArrayCollector
-from brainpy.math.ndarray import Array, add_context, del_context
from ._utils import infer_dyn_vars
-from .base import ObjectTransform
+from .base_object import get_unique_name, ArrayCollector
+from .base_transform import ObjectTransform
+from ..ndarray import Array, add_context, del_context
__all__ = [
'grad', # gradient of scalar function
diff --git a/brainpy/math/object_transform/base_object.py b/brainpy/math/object_transform/base_object.py
new file mode 100644
index 000000000..f5b3b8a85
--- /dev/null
+++ b/brainpy/math/object_transform/base_object.py
@@ -0,0 +1,533 @@
+# -*- coding: utf-8 -*-
+
+import os
+import logging
+import warnings
+from collections import namedtuple
+from typing import Any, Tuple, Callable, Sequence, Dict, Union
+
+from brainpy import errors
+from .collector import Collector, ArrayCollector
+from ..ndarray import Variable, VariableView, TrainVar
+
+StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
+
+__all__ = [
+ 'check_name_uniqueness',
+ 'get_unique_name',
+ 'clear_name_cache',
+
+ 'BrainPyObject', 'Base', 'FunAsObject',
+]
+
+logger = logging.getLogger('brainpy.brainpy_object')
+
+_name2id = dict()
+_typed_names = {}
+
+
+def check_name_uniqueness(name, obj):
+ """Check the uniqueness of the name for the object type."""
+ if not name.isidentifier():
+ raise errors.BrainPyError(f'"{name}" isn\'t a valid identifier '
+ f'according to Python language definition. '
+ f'Please choose another name.')
+ if name in _name2id:
+ if _name2id[name] != id(obj):
+ raise errors.UniqueNameError(
+ f'In BrainPy, each object should have a unique name. '
+ f'However, we detect that {obj} has a used name "{name}". \n'
+ f'If you try to run multiple trials, you may need \n\n'
+ f'>>> brainpy.brainpy_object.clear_name_cache() \n\n'
+ f'to clear all cached names. '
+ )
+ else:
+ _name2id[name] = id(obj)
+
+
+def get_unique_name(type_):
+ """Get the unique name for the given object type."""
+ if type_ not in _typed_names:
+ _typed_names[type_] = 0
+ name = f'{type_}{_typed_names[type_]}'
+ _typed_names[type_] += 1
+ return name
+
+
+def clear_name_cache(ignore_warn=False):
+ """Clear the cached names."""
+ _name2id.clear()
+ _typed_names.clear()
+ if not ignore_warn:
+ logger.warning('All named models and their ids are cleared.')
+
+
+class BrainPyObject(object):
+ """The BrainPyObject class for whole BrainPy ecosystem.
+
+ The subclass of BrainPyObject includes:
+
+ - ``DynamicalSystem`` in *brainpy.dyn.base_object.py*
+ - ``Integrator`` in *brainpy.integrators.base_object.py*
+ - ``FunAsObject`` in *brainpy.brainpy_object.function.py*
+ - ``Optimizer`` in *brainpy.optimizers.py*
+ - ``Scheduler`` in *brainpy.optimizers.py*
+ - and others.
+ """
+
+ _excluded_vars = ()
+
+ def __init__(self, name=None):
+ # check whether the object has a unique name.
+ self._name = None
+ self._name = self.unique_name(name=name)
+ check_name_uniqueness(name=self._name, obj=self)
+
+ # Used to wrap the implicit variables
+ # which cannot be accessed by self.xxx
+ self.implicit_vars = ArrayCollector()
+
+ # Used to wrap the implicit children nodes
+ # which cannot be accessed by self.xxx
+ self.implicit_nodes = Collector()
+
+ def __setattr__(self, key, value) -> None:
+ """Overwrite __setattr__ method for non-changeable Variable setting.
+
+ .. versionadded:: 2.3.1
+
+ Parameters
+ ----------
+ key: str
+ value: Any
+ """
+ if key in self.__dict__:
+ val = self.__dict__[key]
+ if isinstance(val, Variable):
+ val.value = value
+ return
+ super().__setattr__(key, value)
+
+ def tree_flatten(self):
+ """
+ .. versionadded:: 2.3.1
+
+ Returns
+ -------
+
+ """
+ dynamic_names = []
+ dynamic_values = []
+ static_names = []
+ static_values = []
+ for k, v in self.__dict__.items():
+ if isinstance(v, (ArrayCollector, BrainPyObject, Variable)):
+ dynamic_names.append(k)
+ dynamic_values.append(v)
+ else:
+ static_values.append(v)
+ static_names.append(k)
+ return tuple(dynamic_values), (tuple(dynamic_names),
+ tuple(static_names),
+ tuple(static_values))
+
+ @classmethod
+ def tree_unflatten(cls, aux, dynamic_values):
+ """
+
+ .. versionadded:: 2.3.1
+
+ Parameters
+ ----------
+ aux
+ dynamic_values
+
+ Returns
+ -------
+
+ """
+ self = cls.__new__(cls)
+ dynamic_names, static_names, static_values = aux
+ for name, value in zip(dynamic_names, dynamic_values):
+ object.__setattr__(self, name, value)
+ for name, value in zip(static_names, static_values):
+ object.__setattr__(self, name, value)
+ return self
+
+ @property
+ def name(self):
+ """Name of the model."""
+ return self._name
+
+ @name.setter
+ def name(self, name: str = None):
+ self._name = self.unique_name(name=name)
+ check_name_uniqueness(name=self._name, obj=self)
+
+ def register_implicit_vars(self, *variables, **named_variables):
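+ """Register variables that cannot be accessed as ``self.xxx`` attributes
+ so that ``.vars()`` can still collect them. Accepts positional Variables,
+ sequences or dicts of Variables, and keyword-named Variables.
+ """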
+ for variable in variables:
+ if isinstance(variable, Variable):
+ self.implicit_vars[f'var{id(variable)}'] = variable
+ elif isinstance(variable, (tuple, list)):
+ for v in variable:
+ if not isinstance(v, Variable):
+ raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
+ self.implicit_vars[f'var{id(v)}'] = v
+ elif isinstance(variable, dict):
+ for k, v in variable.items():
+ if not isinstance(v, Variable):
+ raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(v)}')
+ self.implicit_vars[k] = v
+ else:
+ raise ValueError(f'Unknown type: {type(variable)}')
+ for key, variable in named_variables.items():
+ if not isinstance(variable, Variable):
+ raise ValueError(f'Must be instance of {Variable.__name__}, but we got {type(variable)}')
+ self.implicit_vars[key] = variable
+
+ def register_implicit_nodes(self, *nodes, node_cls: type = None, **named_nodes):
+ if node_cls is None:
+ node_cls = BrainPyObject
+ for node in nodes:
+ if isinstance(node, node_cls):
+ self.implicit_nodes[node.name] = node
+ elif isinstance(node, (tuple, list)):
+ for n in node:
+ if not isinstance(n, node_cls):
+ raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(n)}')
+ self.implicit_nodes[n.name] = n
+ elif isinstance(node, dict):
+ for k, n in node.items():
+ if not isinstance(n, node_cls):
+ raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(n)}')
+ self.implicit_nodes[k] = n
+ else:
+ raise ValueError(f'Unknown type: {type(node)}')
+ for key, node in named_nodes.items():
+ if not isinstance(node, node_cls):
+ raise ValueError(f'Must be instance of {node_cls.__name__}, but we got {type(node)}')
+ self.implicit_nodes[key] = node
+
+ def vars(self,
+ method: str = 'absolute',
+ level: int = -1,
+ include_self: bool = True,
+ exclude_types: Tuple[type, ...] = None):
+ """Collect all variables in this node and the children nodes.
+
+ Parameters
+ ----------
+ method : str
+ The method to access the variables.
+ level: int
+ The hierarchy level to find variables.
+ include_self: bool
+ Whether to include the variables of self.
+ exclude_types: tuple of type
+ The type to exclude.
+
+ Returns
+ -------
+ gather : ArrayCollector
+ The collection contained (the path, the variable).
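+
+ Examples
+ --------
+ A hypothetical one-variable object (key names depend on the name cache)::
+
+ >>> import brainpy.math as bm
+ >>> class Model(bm.BrainPyObject):
+ ... def __init__(self):
+ ... super().__init__()
+ ... self.a = bm.Variable(bm.zeros(1))
+ >>> list(Model().vars().keys()) # doctest: +SKIP
+ ['Model0.a']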
+ """
+ if exclude_types is None:
+ exclude_types = (VariableView,)
+ nodes = self.nodes(method=method, level=level, include_self=include_self)
+ gather = ArrayCollector()
+ for node_path, node in nodes.items():
+ for k in dir(node):
+ v = getattr(node, k)
+ include = False
+ if isinstance(v, Variable):
+ include = True
+ if len(exclude_types) and isinstance(v, exclude_types):
+ include = False
+ if include:
+ if k not in node._excluded_vars:
+ gather[f'{node_path}.{k}' if node_path else k] = v
+ gather.update({(f'{node_path}.{k}' if node_path else k): v for k, v in node.implicit_vars.items()})
+ return gather
+
+ def train_vars(self, method='absolute', level=-1, include_self=True):
+ """The shortcut for retrieving all trainable variables.
+
+ Parameters
+ ----------
+ method : str
+ The method to access the variables. Support 'absolute' and 'relative'.
+ level: int
+ The hierarchy level to find TrainVar instances.
+ include_self: bool
+ Whether to include the TrainVar instances of self.
+
+ Returns
+ -------
+ gather : ArrayCollector
+ The collection contained (the path, the trainable variable).
+ """
+ return self.vars(method=method, level=level, include_self=include_self).subset(TrainVar)
+
+ def _find_nodes(self, method='absolute', level=-1, include_self=True, _lid=0, _paths=None):
+ if _paths is None:
+ _paths = set()
+ gather = Collector()
+ if include_self:
+ if method == 'absolute':
+ gather[self.name] = self
+ elif method == 'relative':
+ gather[''] = self
+ else:
+ raise ValueError(f'No support for the method of "{method}".')
+ if (level > -1) and (_lid >= level):
+ return gather
+ if method == 'absolute':
+ nodes = []
+ for k, v in self.__dict__.items():
+ if isinstance(v, BrainPyObject):
+ path = (id(self), id(v))
+ if path not in _paths:
+ _paths.add(path)
+ gather[v.name] = v
+ nodes.append(v)
+ for node in self.implicit_nodes.values():
+ path = (id(self), id(node))
+ if path not in _paths:
+ _paths.add(path)
+ gather[node.name] = node
+ nodes.append(node)
+ for v in nodes:
+ gather.update(v._find_nodes(method=method,
+ level=level,
+ _lid=_lid + 1,
+ _paths=_paths,
+ include_self=include_self))
+
+ elif method == 'relative':
+ nodes = []
+ for k, v in self.__dict__.items():
+ if isinstance(v, BrainPyObject):
+ path = (id(self), id(v))
+ if path not in _paths:
+ _paths.add(path)
+ gather[k] = v
+ nodes.append((k, v))
+ for key, node in self.implicit_nodes.items():
+ path = (id(self), id(node))
+ if path not in _paths:
+ _paths.add(path)
+ gather[key] = node
+ nodes.append((key, node))
+ for k1, v1 in nodes:
+ for k2, v2 in v1._find_nodes(method=method,
+ _paths=_paths,
+ _lid=_lid + 1,
+ level=level,
+ include_self=include_self).items():
+ if k2: gather[f'{k1}.{k2}'] = v2
+
+ else:
+ raise ValueError(f'No support for the method of "{method}".')
+ return gather
+
+ def nodes(self, method='absolute', level=-1, include_self=True):
+ """Collect all children nodes.
+
+ Parameters
+ ----------
+ method : str
+ The method to access the nodes.
+ level: int
+ The hierarchy level to find nodes.
+ include_self: bool
+ Whether to include self in the collection.
+
+ Returns
+ -------
+ gather : Collector
+ The collection contained (the path, the node).
+ """
+ return self._find_nodes(method=method, level=level, include_self=include_self)
+
+ def unique_name(self, name=None, type_=None):
+ """Get the unique name for this object.
+
+ Parameters
+ ----------
+ name : str, optional
+ The expected name. If None, the default unique name will be returned.
+ Otherwise, the provided name will be checked to guarantee its uniqueness.
+ type_ : str, optional
+ The name of this class, used for object naming.
+
+ Returns
+ -------
+ name : str
+ The unique name for this object.
+ """
+ if name is None:
+ if type_ is None:
+ return get_unique_name(type_=self.__class__.__name__)
+ else:
+ return get_unique_name(type_=type_)
+ else:
+ check_name_uniqueness(name=name, obj=self)
+ return name
+
+ def state_dict(self):
+ """Returns a dictionary containing a whole state of the module.
+
+ Returns
+ -------
+ out: dict
+ A dictionary containing a whole state of the module.
+ """
+ return self.vars().unique().dict()
+
+ def load_state_dict(self, state_dict: Dict[str, Any], warn: bool = True):
+ """Copy parameters and buffers from :attr:`state_dict` into
+ this module and its descendants.
+
+ Parameters
+ ----------
+ state_dict: dict
+ A dict containing parameters and persistent buffers.
+ warn: bool
+ Warnings when there are missing keys or unexpected keys in the external ``state_dict``.
+
+ Returns
+ -------
+ out: StateLoadResult
+ ``NamedTuple`` with ``missing_keys`` and ``unexpected_keys`` fields:
+
+ * **missing_keys** is a list of str containing the missing keys
+ * **unexpected_keys** is a list of str containing the unexpected keys
+ """
+ variables = self.vars().unique()
+ keys1 = set(state_dict.keys())
+ keys2 = set(variables.keys())
+ unexpected_keys = list(keys1 - keys2)
+ missing_keys = list(keys2 - keys1)
+ for key in keys2.intersection(keys1):
+ variables[key].value = state_dict[key]
+ if warn:
+ if len(unexpected_keys):
+ warnings.warn(f'Unexpected keys in state_dict: {unexpected_keys}', UserWarning)
+ if len(missing_keys):
+ warnings.warn(f'Missing keys in state_dict: {missing_keys}', UserWarning)
+ return StateLoadResult(missing_keys, unexpected_keys)
+
+ def load_states(self, filename, verbose=False):
+ """Load the model states.
+
+ Parameters
+ ----------
+ filename : str
+ The filename which stores the model states.
+ verbose: bool
+ Whether to report the loading progress.
+ """
+ from brainpy.checkpoints import io
+ if not os.path.exists(filename):
+ raise errors.BrainPyError(f'Cannot find the file path: {filename}')
+ elif filename.endswith('.hdf5') or filename.endswith('.h5'):
+ io.load_by_h5(filename, target=self, verbose=verbose)
+ elif filename.endswith('.pkl'):
+ io.load_by_pkl(filename, target=self, verbose=verbose)
+ elif filename.endswith('.npz'):
+ io.load_by_npz(filename, target=self, verbose=verbose)
+ elif filename.endswith('.mat'):
+ io.load_by_mat(filename, target=self, verbose=verbose)
+ else:
+ raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
+
+ def save_states(self, filename, variables=None, **setting):
+ """Save the model states.
+
+ Parameters
+ ----------
+ filename : str
+ The filename to store the model states.
+ variables: optional, dict, ArrayCollector
+ The variables to save. If not provided, all variables retrieved by ``~.vars()`` will be used.
+ """
+ if variables is None:
+ variables = self.vars(method='absolute', level=-1)
+
+ from brainpy.checkpoints import io
+ if filename.endswith('.hdf5') or filename.endswith('.h5'):
+ io.save_as_h5(filename, variables=variables)
+ elif filename.endswith('.pkl') or filename.endswith('.pickle'):
+ io.save_as_pkl(filename, variables=variables)
+ elif filename.endswith('.npz'):
+ io.save_as_npz(filename, variables=variables, **setting)
+ elif filename.endswith('.mat'):
+ io.save_as_mat(filename, variables=variables)
+ else:
+ raise errors.BrainPyError(f'Unknown file format: {filename}. We only supports {io.SUPPORTED_FORMATS}')
+
+ # def to(self, devices):
+ # global math
+ # if math is None: from brainpy import math
+ #
+ # def cpu(self):
+ # global math
+ # if math is None: from brainpy import math
+ #
+ # all_vars = self.vars().unique()
+ # for data in all_vars.values():
+ # data[:] = asarray(data.value)
+ #
+ # def cuda(self):
+ # global math
+ # if math is None: from brainpy import math
+ #
+ # def tpu(self):
+ # global math
+ # if math is None: from brainpy import math
+
+
+Base = BrainPyObject
+
+
+class FunAsObject(BrainPyObject):
+ """Transform a Python function as a :py:class:`~.BrainPyObject`.
+
+ Parameters
+ ----------
+ f : callable
+ The function to wrap.
+ child_objs : optional, BrainPyObject, sequence of BrainPyObject, dict
+ The nodes in the defined function ``f``.
+ dyn_vars : optional, Variable, sequence of Variable, dict
+ The dynamically changed variables.
+ name : optional, str
+ The function name.
+ """
+
+ def __init__(
+ self,
+ f: Callable,
+ child_objs: Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
+ dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
+ name: str = None
+ ):
+ super(FunAsObject, self).__init__(name=name)
+ self._f = f
+ if child_objs is not None:
+ self.register_implicit_nodes(child_objs)
+ if dyn_vars is not None:
+ self.register_implicit_vars(dyn_vars)
+
+ def __call__(self, *args, **kwargs):
+ return self._f(*args, **kwargs)
+
+ def __repr__(self) -> str:
+ from brainpy.tools import repr_context
+ name = self.__class__.__name__
+ indent = " " * (len(name) + 1)
+ indent2 = indent + " " * len('nodes=')
+ nodes = [repr_context(str(n), indent2) for n in self.implicit_nodes.values()]
+ node_string = ", \n".join(nodes)
+ return (f'{name}(nodes=[{node_string}],\n' +
+ " " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')
diff --git a/brainpy/math/object_transform/base.py b/brainpy/math/object_transform/base_transform.py
similarity index 89%
rename from brainpy/math/object_transform/base.py
rename to brainpy/math/object_transform/base_transform.py
index cc9d9d8da..4a5301ea6 100644
--- a/brainpy/math/object_transform/base.py
+++ b/brainpy/math/object_transform/base_transform.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from brainpy.base.base import BrainPyObject
+from .base_object import BrainPyObject
__all__ = [
'ObjectTransform'
diff --git a/brainpy/math/object_transform/collector.py b/brainpy/math/object_transform/collector.py
new file mode 100644
index 000000000..ecc24d820
--- /dev/null
+++ b/brainpy/math/object_transform/collector.py
@@ -0,0 +1,224 @@
+# -*- coding: utf-8 -*-
+
+
+from typing import Dict, Sequence, Union
+
+from jax.tree_util import register_pytree_node
+from jax.util import safe_zip
+
+from ..ndarray import Array
+
+
+__all__ = [
+ 'Collector',
+ 'ArrayCollector',
+ 'TensorCollector',
+]
+
+
+class Collector(dict):
+ """A Collector is a dictionary (name, var) with some additional methods to make manipulation
+ of collections of variables easy. A Collector is ordered by insertion order. It is the object
+ returned by BrainPyObject.vars() and used as input in many Collector instance: optimizers, jit, etc..."""
+
+ def __setitem__(self, key, value):
+ """Overload bracket assignment to catch potential conflicts during assignment."""
+ if key in self:
+ if id(self[key]) != id(value):
+ raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
+ dict.__setitem__(self, key, value)
+
+ def replace(self, key, new_value):
+ """Replace the original key with the new value."""
+ self.pop(key)
+ self[key] = new_value
+
+ def update(self, other, **kwargs):
+ assert isinstance(other, (dict, list, tuple))
+ if isinstance(other, dict):
+ for key, value in other.items():
+ self[key] = value
+ elif isinstance(other, (tuple, list)):
+ num = len(self)
+ for i, value in enumerate(other):
+ self[f'_var{i + num}'] = value
+ else:
+ raise ValueError(f'Only supports dict/list/tuple, but we got {type(other)}')
+ for key, value in kwargs.items():
+ self[key] = value
+ return self
+
+ def __add__(self, other):
+ """Merging two dicts.
+
+ Parameters
+ ----------
+ other: dict
+ The other dict instance.
+
+ Returns
+ -------
+ gather: Collector
+ The new collector.
+ """
+ gather = type(self)(self)
+ gather.update(other)
+ return gather
+
+ def __sub__(self, other: Union[Dict, Sequence]):
+ """Remove other item in the collector.
+
+ Parameters
+ ----------
+ other: dict, sequence
+ The items to remove.
+
+ Returns
+ -------
+ gather: Collector
+ The new collector.
+ """
+ if not isinstance(other, (dict, tuple, list)):
+ raise ValueError(f'Only support dict/tuple/list, but we got {type(other)}.')
+ gather = type(self)(self)
+ if isinstance(other, dict):
+ for key, val in other.items():
+ if key in gather:
+ if id(val) != id(gather[key]):
+ raise ValueError(f'Cannot remove {key}, because we got two different values: '
+ f'{val} != {gather[key]}')
+ gather.pop(key)
+ else:
+ raise ValueError(f'Cannot remove {key}, because we do not find it '
+ f'in {self.keys()}.')
+ elif isinstance(other, (list, tuple)):
+ id_to_keys = {}
+ for k, v in self.items():
+ id_ = id(v)
+ if id_ not in id_to_keys:
+ id_to_keys[id_] = []
+ id_to_keys[id_].append(k)
+
+ keys_to_remove = []
+ for key in other:
+ if isinstance(key, str):
+ keys_to_remove.append(key)
+ else:
+ keys_to_remove.extend(id_to_keys[id(key)])
+
+ for key in set(keys_to_remove):
+ if key in gather:
+ gather.pop(key)
+ else:
+ raise ValueError(f'Cannot remove {key}, because we do not find it '
+ f'in {self.keys()}.')
+ else:
+ raise KeyError(f'Unknown type of "other". Only support dict/tuple/list, but we got {type(other)}')
+ return gather
+
+ def subset(self, var_type):
+ """Get the subset of the (key, value) pair.
+
+ ``subset()`` can be used to get a subset of some class:
+
+ >>> import brainpy as bp
+ >>>
+ >>> some_collector = Collector()
+ >>>
+ >>> # get all trainable variables
+ >>> some_collector.subset(bp.math.TrainVar)
+ >>>
+ >>> # get all Variable
+ >>> some_collector.subset(bp.math.Variable)
+
+ or, it can be used to get a subset of integrators:
+
+ >>> # get all ODE integrators
+ >>> some_collector.subset(bp.ode.ODEIntegrator)
+
+ Parameters
+ ----------
+ var_type : type
+ The type/class to match.
+ """
+ gather = type(self)()
+ for key, value in self.items():
+ if isinstance(value, var_type):
+ gather[key] = value
+ return gather
+
+ def unique(self):
+ """Get a new type of collector with unique values.
+
+ If one value is assigned to two or more keys,
+ then only one pair of (key, value) will be returned.
+ """
+ gather = type(self)()
+ seen = set()
+ for k, v in self.items():
+ if id(v) not in seen:
+ seen.add(id(v))
+ gather[k] = v
+ return gather
+
+
+class ArrayCollector(Collector):
+ """A ArrayCollector is a dictionary (name, var)
+ with some additional methods to make manipulation
+ of collections of variables easy. A Collection
+ is ordered by insertion order. It is the object
+ returned by DynamicalSystem.vars() and used as input
+ in many DynamicalSystem instance: optimizers, Jit, etc..."""
+
+ def __setitem__(self, key, value):
+ """Overload bracket assignment to catch potential conflicts during assignment."""
+
+ assert isinstance(value, Array)
+ if key in self:
+ if id(self[key]) != id(value):
+ raise ValueError(f'Name "{key}" conflicts: same name for {value} and {self[key]}.')
+ dict.__setitem__(self, key, value)
+
+ def assign(self, inputs):
+ """Assign data to all values.
+
+ Parameters
+ ----------
+ inputs : dict
+ The data for each value in this collector.
+ """
+ if len(self) != len(inputs):
+ raise ValueError(f'The target has {len(self)} variables, while the '
+ f'input provides {len(inputs)} values.')
+ for key, v in self.items():
+ v.value = inputs[key]
+
+ def dict(self):
+ """Get a dict with the key and the value data.
+ """
+ gather = dict()
+ for k, v in self.items():
+ gather[k] = v.value
+ return gather
+
+ def data(self):
+ """Get all data in each value."""
+ return [x.value for x in self.values()]
+
+ @classmethod
+ def from_other(cls, other: Union[Sequence, Dict]):
+ if isinstance(other, (tuple, list)):
+ return cls({id(o): o for o in other})
+ elif isinstance(other, dict):
+ return cls(other)
+ else:
+ raise TypeError
+
+
+TensorCollector = ArrayCollector
+
+register_pytree_node(
+ ArrayCollector,
+ lambda x: (x.values(), x.keys()),
+ lambda keys, values: ArrayCollector(safe_zip(keys, values))
+)
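
A quick sketch of the set-like operations defined above (assuming the usual
``brainpy.math`` exports for ``Variable``/``TrainVar``):

    import brainpy.math as bm
    from brainpy.math.object_transform.collector import ArrayCollector

    col = ArrayCollector()
    col['w'] = bm.TrainVar(bm.ones(2))
    col['v'] = bm.Variable(bm.zeros(2))
    col['v2'] = col['v']          # the same object under a second name

    col.subset(bm.TrainVar)       # -> only 'w'
    col.unique()                  # 'v' kept once, duplicate 'v2' dropped
    (col - ['w']).keys()          # -> dict_keys(['v', 'v2'])
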
diff --git a/brainpy/math/object_transform/controls.py b/brainpy/math/object_transform/controls.py
index 6615cc7d0..ba4bd3594 100644
--- a/brainpy/math/object_transform/controls.py
+++ b/brainpy/math/object_transform/controls.py
@@ -13,14 +13,14 @@
from jax.core import UnexpectedTracerError
from brainpy import errors, tools, check
-from brainpy.base.naming import get_unique_name
-from brainpy.base import ArrayCollector
-from brainpy.math.ndarray import (Array, Variable,
+from .base_object import get_unique_name, BrainPyObject
+from .collector import ArrayCollector
+from ..ndarray import (Array, Variable,
add_context,
del_context)
-from brainpy.math.numpy_ops import as_device_array
+from ..numpy_ops import as_jax
from ._utils import infer_dyn_vars
-from .base import ObjectTransform
+from .base_transform import ObjectTransform
__all__ = [
'make_loop',
@@ -298,7 +298,7 @@ def _body_fun(op):
def _cond_fun(op):
dyn_values, static_values = op
for v, d in zip(dyn_vars, dyn_values): v._value = d
- return as_device_array(cond_fun(static_values))
+ return as_jax(cond_fun(static_values))
name = get_unique_name('_brainpy_object_oriented_make_while_')
@@ -445,6 +445,7 @@ def cond(
false_fun: Union[Callable, jnp.ndarray, Array, float, int, bool],
operands: Any,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
+ child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
"""Simple conditional statement (if-else) with instance of :py:class:`~.Variable`.
@@ -477,6 +478,10 @@ def cond(
can be a scalar, array, or any pytree (nested Python tuple/list/dict) thereof.
dyn_vars: optional, Variable, sequence of Variable, dict
The dynamically changed variables.
+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
+ The children objects used in the target function.
+
+ .. versionadded:: 2.3.1
Returns
-------
@@ -487,8 +492,11 @@ def cond(
true_fun = _check_f(true_fun)
false_fun = _check_f(false_fun)
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
+ dyn_vars = ArrayCollector(dyn_vars)
dyn_vars.update(infer_dyn_vars(true_fun))
dyn_vars.update(infer_dyn_vars(false_fun))
+ for obj in check.is_all_objs(child_objs, out_as='tuple'):
+ dyn_vars.update(obj.vars().unique())
dyn_vars = list(ArrayCollector(dyn_vars).unique().values())
name = get_unique_name('_brainpy_object_oriented_cond_')
@@ -539,6 +547,7 @@ def ifelse(
branches: Sequence[Callable],
operands: Any = None,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
+ child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
show_code: bool = False,
):
"""``If-else`` control flows looks like native Pythonic programming.
@@ -578,6 +587,10 @@ def ifelse(
The dynamically changed variables.
show_code: bool
Whether show the formatted code.
+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
+ The children objects used in the target function.
+
+ .. versionadded:: 2.3.1
Returns
-------
@@ -602,7 +615,9 @@ def ifelse(
dyn_vars = ArrayCollector(dyn_vars)
for f in branches:
dyn_vars += infer_dyn_vars(f)
- dyn_vars = tuple(dyn_vars.values())
+ for obj in check.is_all_objs(child_objs, out_as='tuple'):
+ dyn_vars.update(obj.vars().unique())
+ dyn_vars = tuple(dyn_vars.unique().values())
# format new codes
if len(conditions) == 1:
@@ -647,6 +662,7 @@ def for_loop(
operands: Any,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
out_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
+ child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
reverse: bool = False,
unroll: int = 1,
):
@@ -727,6 +743,10 @@ def for_loop(
Optional positive int specifying, in the underlying operation of the
scan primitive, how many scan iterations to unroll within a single
iteration of a loop.
+ child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
+ The children objects used in the target function.
+
+ .. versionadded:: 2.3.1
Returns
-------
@@ -734,7 +754,10 @@ def for_loop(
The stacked outputs of ``body_fun`` when scanned over the leading axis of the inputs.
"""
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
+ dyn_vars = ArrayCollector(dyn_vars)
dyn_vars.update(infer_dyn_vars(body_fun))
+ for obj in check.is_all_objs(child_objs, out_as='tuple'):
+ dyn_vars.update(obj.vars().unique())
dyn_vars = list(ArrayCollector(dyn_vars).unique().values())
outs, _ = tree_flatten(out_vars, lambda s: isinstance(s, Variable))
for v in outs:
@@ -785,6 +808,7 @@ def while_loop(
cond_fun: Callable,
operands: Any,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
+ child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
):
"""``while-loop`` control flow with :py:class:`~.Variable`.
@@ -831,13 +855,18 @@ def while_loop(
The dynamically changed variables.
operands: Any
The operands for ``body_fun`` and ``cond_fun`` functions.
+  child_objs: optional, dict, sequence of BrainPyObject, BrainPyObject
+    The child objects used in the target function.
+
+    .. versionadded:: 2.3.1
"""
# iterable variables
dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
dyn_vars = ArrayCollector(dyn_vars)
dyn_vars.update(infer_dyn_vars(body_fun))
dyn_vars.update(infer_dyn_vars(cond_fun))
+ for obj in check.is_all_objs(child_objs, out_as='tuple'):
+ dyn_vars.update(obj.vars().unique())
dyn_vars = tuple(dyn_vars.values())
if not isinstance(operands, (list, tuple)):
operands = (operands,)
diff --git a/brainpy/math/object_transform/function.py b/brainpy/math/object_transform/function.py
index e0bbe2f1c..0912e0946 100644
--- a/brainpy/math/object_transform/function.py
+++ b/brainpy/math/object_transform/function.py
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
+
import warnings
from typing import Union, Sequence, Dict, Callable
-from brainpy.base import FunAsObject, BrainPyObject
-from brainpy.math.ndarray import Variable
+from .base_object import FunAsObject, BrainPyObject
+from ..ndarray import Variable
__all__ = [
'to_object',
@@ -17,7 +18,7 @@ def to_object(
child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
name: str = None
-) -> BrainPyObject:
+):
"""Transform a Python function to :py:class:`~.BrainPyObject`.
Parameters
@@ -54,7 +55,7 @@ def to_dynsys(
child_objs: Union[Callable, BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]] = None,
dyn_vars: Union[Variable, Sequence[Variable], Dict[str, Variable]] = None,
name: str = None
-) -> BrainPyObject:
+):
"""Transform a Python function to a :py:class:`~.DynamicalSystem`.
Parameters
@@ -76,8 +77,9 @@ def to_dynsys(
from brainpy.dyn.base import FuncAsDynSys
if f is None:
- return lambda func: FuncAsDynSys(f=func, child_objs=child_objs, dyn_vars=dyn_vars, name=name)
-
+ def wrap(func) -> FuncAsDynSys:
+ return FuncAsDynSys(f=func, child_objs=child_objs, dyn_vars=dyn_vars, name=name)
+ return wrap
else:
if child_objs is None:
raise ValueError(f'"child_objs" cannot be None when "f" is provided.')
diff --git a/brainpy/math/object_transform/jit.py b/brainpy/math/object_transform/jit.py
index 4195dbad3..c2bc07ad7 100644
--- a/brainpy/math/object_transform/jit.py
+++ b/brainpy/math/object_transform/jit.py
@@ -17,9 +17,9 @@
from jax.core import UnexpectedTracerError, ConcretizationTypeError
from brainpy import errors, tools, check
-from brainpy.base import BrainPyObject
-from brainpy.math.ndarray import Variable, add_context, del_context
-from .base import ObjectTransform
+from .base_transform import ObjectTransform
+from .base_object import BrainPyObject
+from ..ndarray import Variable, add_context, del_context
__all__ = [
'jit',
diff --git a/brainpy/math/object_transform/parallels.py b/brainpy/math/object_transform/parallels.py
index 5dfa6aaad..0a2152d00 100644
--- a/brainpy/math/object_transform/parallels.py
+++ b/brainpy/math/object_transform/parallels.py
@@ -24,10 +24,9 @@
from jax.core import UnexpectedTracerError
from brainpy import errors
-from brainpy.base.base import BrainPyObject
-from brainpy.base.collector import ArrayCollector
-from brainpy.math.random import RandomState
-from brainpy.math.ndarray import Array
+from .base_object import BrainPyObject, ArrayCollector
+from ..random import RandomState
+from ..ndarray import Array
from brainpy.tools.codes import change_func_name
__all__ = [
diff --git a/brainpy/base/tests/test_base.py b/brainpy/math/object_transform/tests/test_base.py
similarity index 100%
rename from brainpy/base/tests/test_base.py
rename to brainpy/math/object_transform/tests/test_base.py
diff --git a/brainpy/base/tests/test_circular_reference.py b/brainpy/math/object_transform/tests/test_circular_reference.py
similarity index 100%
rename from brainpy/base/tests/test_circular_reference.py
rename to brainpy/math/object_transform/tests/test_circular_reference.py
diff --git a/brainpy/base/tests/test_collector.py b/brainpy/math/object_transform/tests/test_collector.py
similarity index 99%
rename from brainpy/base/tests/test_collector.py
rename to brainpy/math/object_transform/tests/test_collector.py
index deda39ea5..a7189bd21 100644
--- a/brainpy/base/tests/test_collector.py
+++ b/brainpy/math/object_transform/tests/test_collector.py
@@ -272,7 +272,7 @@ def test_net_vars_2():
def test_hidden_variables():
- class BPClass(bp.base.BrainPyObject):
+ class BPClass(bp.BrainPyObject):
_excluded_vars = ('_rng_', )
def __init__(self):
diff --git a/brainpy/base/tests/test_namechecking.py b/brainpy/math/object_transform/tests/test_namechecking.py
similarity index 100%
rename from brainpy/base/tests/test_namechecking.py
rename to brainpy/math/object_transform/tests/test_namechecking.py
diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py
index 7fde6ddd6..08789c387 100644
--- a/brainpy/math/operators/op_register.py
+++ b/brainpy/math/operators/op_register.py
@@ -6,8 +6,8 @@
import brainpylib
from jax.tree_util import tree_map
-from brainpy.base import BrainPyObject
-from brainpy.math.ndarray import Array
+from ..object_transform.base_object import BrainPyObject
+from ..ndarray import Array
__all__ = [
'XLACustomOp',
diff --git a/brainpy/math/operators/pre_syn_post.py b/brainpy/math/operators/pre_syn_post.py
index 101710312..c65d9be95 100644
--- a/brainpy/math/operators/pre_syn_post.py
+++ b/brainpy/math/operators/pre_syn_post.py
@@ -8,7 +8,6 @@
from brainpy.errors import MathError
from brainpy.math.numpy_ops import as_jax
-from brainpy.types import ArrayType
__all__ = [
# pre-to-post
@@ -43,10 +42,10 @@ def _raise_pre_ids_is_none(pre_ids):
f'(brainpy.math.ndim(pre_values) != 0).')
-def pre2post_event_sum(events: ArrayType,
- pre2post: Tuple[ArrayType, ArrayType],
+def pre2post_event_sum(events,
+ pre2post,
post_num: int,
- values: Union[float, ArrayType] = 1.):
+ values = 1.):
"""The pre-to-post event-driven synaptic summation with `CSR` synapse structure.
When ``values`` is a scalar, this function is equivalent to
@@ -103,11 +102,11 @@ def pre2post_event_sum(events: ArrayType,
transpose=True)
-def pre2post_coo_event_sum(events: ArrayType,
- pre_ids: ArrayType,
- post_ids: ArrayType,
+def pre2post_coo_event_sum(events,
+ pre_ids,
+ post_ids,
post_num: int,
- values: Union[float, ArrayType] = 1.):
+ values = 1.):
"""The pre-to-post synaptic computation with event-driven summation.
Parameters
diff --git a/brainpy/math/operators/sparse_matmul.py b/brainpy/math/operators/sparse_matmul.py
index 6feae9678..8418a4aaf 100644
--- a/brainpy/math/operators/sparse_matmul.py
+++ b/brainpy/math/operators/sparse_matmul.py
@@ -9,7 +9,6 @@
from brainpy.math.ndarray import Array
from brainpy.math.numpy_ops import as_jax
-from brainpy.types import ArrayType
__all__ = [
'sparse_matmul',
@@ -18,10 +17,10 @@
]
-def event_csr_matvec(values: ArrayType,
- indices: ArrayType,
- indptr: ArrayType,
- events: ArrayType,
+def event_csr_matvec(values,
+ indices,
+ indptr,
+ events,
shape: Tuple[int, int],
transpose: bool = False):
"""The pre-to-post event-driven synaptic summation with `CSR` synapse structure.
diff --git a/brainpy/math/operators/tests/test_op_register.py b/brainpy/math/operators/tests/test_op_register.py
index 4921e291d..6f4b8ab47 100644
--- a/brainpy/math/operators/tests/test_op_register.py
+++ b/brainpy/math/operators/tests/test_op_register.py
@@ -7,6 +7,8 @@
import brainpy as bp
import brainpy.math as bm
+
+bm.random.seed()
bm.set_platform('cpu')
diff --git a/brainpy/math/others.py b/brainpy/math/others.py
index 6e140d6dd..ddf398f84 100644
--- a/brainpy/math/others.py
+++ b/brainpy/math/others.py
@@ -9,16 +9,16 @@
from .environment import get_dt, get_int
__all__ = [
- 'form_shared_args'
+ 'shared_args_over_time'
]
-def form_shared_args(num_step: Optional[int] = None,
- duration: Optional[float] = None,
- dt: Optional[float] = None,
- t0: float = 0.,
- include_dt: bool = True):
- """Form a shared argument for the inference of a :py:class:`~.DynamicalSystem`.
+def shared_args_over_time(num_step: Optional[int] = None,
+ duration: Optional[float] = None,
+ dt: Optional[float] = None,
+ t0: float = 0.,
+ include_dt: bool = True):
+ """Form a shared argument over time for the inference of a :py:class:`~.DynamicalSystem`.
Parameters
----------
diff --git a/brainpy/math/random.py b/brainpy/math/random.py
index ffead3a19..ccebc533b 100644
--- a/brainpy/math/random.py
+++ b/brainpy/math/random.py
@@ -14,14 +14,13 @@
from jax.tree_util import register_pytree_node
from brainpy.check import jit_error_checking
-from brainpy.errors import UnsupportedError
from brainpy.math.ndarray import Array, Variable
from ._utils import wraps
__all__ = [
'RandomState', 'Generator',
- 'seed', 'default_rng',
+ 'seed', 'default_rng', 'get_rng',
'rand', 'randint', 'random_integers', 'randn', 'random',
'random_sample', 'ranf', 'sample', 'choice', 'permutation', 'shuffle', 'beta',
@@ -430,6 +429,7 @@ def __init__(self,
'seed will be removed since 2.4.0', UserWarning)
if seed_or_key is None:
+ # key = DEFAULT.split_key()
seed_or_key = np.random.randint(0, 100000, 2, dtype=np.uint32)
if isinstance(seed_or_key, int):
key = jr.PRNGKey(seed_or_key)
@@ -443,9 +443,6 @@ def __init__(self,
def __repr__(self) -> str:
print_code = repr(self.value)
i = print_code.index('(')
-
- # if 'DeviceArray' in print_code:
- # print_code = print_code.replace('DeviceArray', '')
name = self.__class__.__name__
return f'{name}(key={print_code[i:]})'
@@ -514,9 +511,6 @@ def split_keys(self, n):
self._value = keys[0]
return keys[1:]
- def update(self, value):
- raise UnsupportedError(f'Do not support change the value of a {self.__class__.__name__}.')
-
# ---------------- #
# random functions #
# ---------------- #
@@ -1134,9 +1128,16 @@ def noncentral_f(self, dfnum, dfden, nonc, size=None, key=None):
del __a
+def get_rng(seed_or_key=None, clone: bool = True) -> RandomState:
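+  # Return the default random state (optionally cloned) when no seed/key is
+  # given; otherwise build a fresh ``RandomState`` from ``seed_or_key``.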
+ if seed_or_key is None:
+ return DEFAULT.clone() if clone else DEFAULT
+ else:
+ return RandomState(seed_or_key)
+
+
@wraps(np.random.default_rng)
-def default_rng(seed=None):
- return RandomState(seed)
+def default_rng(seed_or_key=None):
+ return RandomState(seed_or_key)
@wraps(np.random.seed)
diff --git a/brainpy/math/surrogate/tests/test_compat.py b/brainpy/math/surrogate/tests/test_compat.py
index 17d8a33a6..79484c569 100644
--- a/brainpy/math/surrogate/tests/test_compat.py
+++ b/brainpy/math/surrogate/tests/test_compat.py
@@ -5,6 +5,7 @@
from functools import partial
import brainpy.math as bm
+bm.random.seed()
def test_sp_sigmoid_grad():
diff --git a/brainpy/modes.py b/brainpy/modes.py
index c0dc1c32d..3505fbbfb 100644
--- a/brainpy/modes.py
+++ b/brainpy/modes.py
@@ -1,9 +1,16 @@
# -*- coding: utf-8 -*-
+"""
+This module is deprecated since version 2.3.1.
+Please use ``brainpy.math.*`` instead.
+"""
+
+
import numpy as np
import brainpy.math as bm
+
__all__ = [
'Mode',
'NormalMode',
diff --git a/brainpy/optimizers/optimizer.py b/brainpy/optimizers/optimizer.py
index f41ef41d1..fb0f8549c 100644
--- a/brainpy/optimizers/optimizer.py
+++ b/brainpy/optimizers/optimizer.py
@@ -3,11 +3,9 @@
from typing import Union, Sequence, Dict, Optional, Tuple
import jax.numpy as jnp
-from jax.lax import cond, rsqrt
+from jax.lax import cond
import brainpy.math as bm
-from brainpy.base.base import BrainPyObject
-from brainpy.base.collector import ArrayCollector
from brainpy.errors import MathError
from .scheduler import make_schedule, Scheduler
@@ -26,7 +24,7 @@
]
-class Optimizer(BrainPyObject):
+class Optimizer(bm.BrainPyObject):
"""Base Optimizer Class.
Parameters
@@ -38,7 +36,7 @@ class Optimizer(BrainPyObject):
lr: Scheduler # learning rate
'''Learning rate'''
- vars_to_train: ArrayCollector # variables to train
+ vars_to_train: bm.ArrayCollector # variables to train
'''Variables to train.'''
def __init__(
@@ -49,7 +47,7 @@ def __init__(
):
super(Optimizer, self).__init__(name=name)
self.lr: Scheduler = make_schedule(lr)
- self.vars_to_train = ArrayCollector()
+ self.vars_to_train = bm.ArrayCollector()
self.register_vars(train_vars)
def register_vars(self, train_vars: Optional[Dict[str, bm.Variable]] = None):
diff --git a/brainpy/optimizers/scheduler.py b/brainpy/optimizers/scheduler.py
index 0623b3434..e74697124 100644
--- a/brainpy/optimizers/scheduler.py
+++ b/brainpy/optimizers/scheduler.py
@@ -2,7 +2,6 @@
import jax.numpy as jnp
-from brainpy.base.base import BrainPyObject
from brainpy.errors import MathError
import brainpy.math as bm
@@ -30,7 +29,7 @@ def make_schedule(scalar_or_schedule):
raise TypeError(type(scalar_or_schedule))
-class Scheduler(BrainPyObject):
+class Scheduler(bm.BrainPyObject):
"""The learning rate scheduler."""
def __init__(self, lr):
diff --git a/brainpy/running/runner.py b/brainpy/running/runner.py
index 0923b4255..64267a29c 100644
--- a/brainpy/running/runner.py
+++ b/brainpy/running/runner.py
@@ -8,7 +8,6 @@
import numpy as np
from brainpy import math as bm, check
-from brainpy.base import BrainPyObject
from brainpy.errors import MonitorError, RunningError
from brainpy.tools import DotDict
from . import constants as C
@@ -18,7 +17,7 @@
]
-class Runner(BrainPyObject):
+class Runner(bm.BrainPyObject):
"""Base Runner.
Parameters
@@ -64,12 +63,12 @@ class Runner(BrainPyObject):
jit: Dict[str, bool]
'''Flag to denote whether to use JIT.'''
- target: BrainPyObject
+ target: bm.BrainPyObject
'''The target model to run.'''
def __init__(
self,
- target: BrainPyObject,
+ target: bm.BrainPyObject,
monitors: Union[Sequence, Dict] = None,
fun_monitors: Dict[str, Callable] = None,
jit: Union[bool, Dict[str, bool]] = True,
@@ -100,34 +99,12 @@ def __init__(
# format string monitors
monitors = self._format_seq_monitors(monitors)
# get monitor targets
- monitors = self._find_monitor_targets(monitors)
+ monitors = self._find_seq_monitor_targets(monitors)
elif isinstance(monitors, dict):
- _monitors = dict()
- for key, val in monitors.items():
- if not isinstance(key, str):
- raise MonitorError('Expect the key of the dict "monitors" must be a string. But got '
- f'{type(key)}: {key}')
- if isinstance(val, bm.Variable):
- val = (val, None)
- if isinstance(val, (tuple, list)):
- if not isinstance(val[0], bm.Variable):
- raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
- f'But we got {val}')
- if len(val) == 1:
- _monitors[key] = (val[0], None)
- elif len(val) == 2:
- if isinstance(val[1], (int, np.integer)):
- idx = bm.array([val[1]])
- else:
- idx = None if val[1] is None else bm.asarray(val[1])
- _monitors[key] = (val[0], idx)
- else:
- raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
- f'But we got {val}')
- else:
- raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
- f'But we got {val}')
- monitors = _monitors
+ # format string monitors
+ monitors = self._format_dict_monitors(monitors)
+ # get monitor targets
+ monitors = self._find_dict_monitor_targets(monitors)
else:
raise MonitorError(f'We only supports a format of list/tuple/dict of '
f'"vars", while we got {type(monitors)}.')
@@ -136,7 +113,7 @@ def __init__(
# deprecated func_monitors
if fun_monitors is not None:
if isinstance(fun_monitors, dict):
- warnings.warn("`func_monitors` is deprecated since version 2.3.1. "
+ warnings.warn("`fun_monitors` is deprecated since version 2.3.1. "
"Define `func_monitors` in `monitors`")
check.is_dict_data(fun_monitors, key_type=str, val_type=types.FunctionType)
self._monitors.update(fun_monitors)
@@ -159,7 +136,7 @@ def __init__(
def _format_seq_monitors(self, monitors):
if not isinstance(monitors, (tuple, list)):
- raise TypeError(f'Must be a sequence, but we got {type(monitors)}')
+ raise TypeError(f'Must be a tuple/list, but we got {type(monitors)}')
_monitors = []
for mon in monitors:
if isinstance(mon, str):
@@ -182,7 +159,40 @@ def _format_seq_monitors(self, monitors):
raise MonitorError(f'We do not support monitor with {type(mon)}: {mon}')
return _monitors
- def _find_monitor_targets(self, _monitors):
+ def _format_dict_monitors(self, monitors):
+ if not isinstance(monitors, dict):
+ raise TypeError(f'Must be a dict, but we got {type(monitors)}')
+ _monitors = dict()
+ for key, val in monitors.items():
+ if not isinstance(key, str):
+ raise MonitorError('Expect the key of the dict "monitors" must be a string. But got '
+ f'{type(key)}: {key}')
+ if isinstance(val, (bm.Variable, str)):
+ val = (val, None)
+
+ if isinstance(val, (tuple, list)):
+ if not isinstance(val[0], (bm.Variable, str)):
+ raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
+ f'But we got {val}')
+ if len(val) == 1:
+ _monitors[key] = (val[0], None)
+ elif len(val) == 2:
+ if isinstance(val[1], (int, np.integer)):
+ idx = bm.array([val[1]])
+ else:
+ idx = None if val[1] is None else bm.asarray(val[1])
+ _monitors[key] = (val[0], idx)
+ else:
+ raise MonitorError('Expect the format of (variable, index) in the monitor setting. '
+ f'But we got {val}')
+ elif callable(val):
+ _monitors[key] = val
+ else:
+ raise MonitorError('The value of dict monitor expect a sequence with (variable, index) '
+ f'or a callable function. But we got {val}')
+ return _monitors
+
+ def _find_seq_monitor_targets(self, _monitors):
if not isinstance(_monitors, (tuple, list)):
raise TypeError(f'Must be a sequence, but we got {type(_monitors)}')
# get monitor targets
@@ -213,6 +223,43 @@ def _find_monitor_targets(self, _monitors):
monitors[key] = (getattr(master, splits[-1]), index)
return monitors
+ def _find_dict_monitor_targets(self, _monitors):
+ if not isinstance(_monitors, dict):
+ raise TypeError(f'Must be a dict, but we got {type(_monitors)}')
+ # get monitor targets
+ monitors = {}
+ name2node = None
+ for _key, _mon in _monitors.items():
+ if isinstance(_mon, str):
+ if name2node is None:
+ name2node = {node.name: node for node in list(self.target.nodes(level=-1).unique().values())}
+
+ key, index = _mon[0], _mon[1]
+ splits = key.split('.')
+ if len(splits) == 1:
+ if not hasattr(self.target, splits[0]):
+          raise RunningError(f'{self.target} does not have variable {key}.')
+ monitors[key] = (getattr(self.target, splits[-1]), index)
+ else:
+ if not hasattr(self.target, splits[0]):
+ if splits[0] not in name2node:
+ raise MonitorError(f'Cannot find target {key} in monitor of {self.target}, please check.')
+ else:
+ master = name2node[splits[0]]
+ assert len(splits) == 2
+ monitors[key] = (getattr(master, splits[-1]), index)
+ else:
+ master = self.target
+ for s in splits[:-1]:
+ try:
+ master = getattr(master, s)
+ except KeyError:
+ raise MonitorError(f'Cannot find {key} in {master}, please check.')
+ monitors[key] = (getattr(master, splits[-1]), index)
+ else:
+ monitors[_key] = _mon
+ return monitors
+
def __del__(self):
if hasattr(self, 'mon'):
for key in tuple(self.mon.keys()):
diff --git a/brainpy/tools/codes.py b/brainpy/tools/codes.py
index 6a3b9c330..adad57764 100644
--- a/brainpy/tools/codes.py
+++ b/brainpy/tools/codes.py
@@ -4,7 +4,8 @@
import re
from types import LambdaType
-from brainpy.base.base import BrainPyObject
+BrainPyObject = None
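+# ``BrainPyObject`` is imported lazily in ``repr_object`` below to avoid a
+# circular import between ``brainpy.tools`` and ``brainpy.math``.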
+
__all__ = [
'repr_object',
@@ -27,6 +28,9 @@
def repr_object(x):
+ global BrainPyObject
+ if BrainPyObject is None:
+ from brainpy.math import BrainPyObject
if isinstance(x, BrainPyObject):
return x.name
elif callable(x):
diff --git a/brainpy/train/back_propagation.py b/brainpy/train/back_propagation.py
index f3e334596..95fdce99f 100644
--- a/brainpy/train/back_propagation.py
+++ b/brainpy/train/back_propagation.py
@@ -112,7 +112,8 @@ def __init__(
lr = optim.ExponentialDecay(lr=0.025, decay_steps=1, decay_rate=0.99975)
optimizer = optim.Adam(lr=lr)
self.optimizer: optim.Optimizer = optimizer
- self.optimizer.register_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique())
+ if len(self.optimizer.vars_to_train) == 0:
+ self.optimizer.register_vars(self.target.vars(level=-1, include_self=True).subset(bm.TrainVar).unique())
# loss function
self.loss_has_aux = loss_has_aux
@@ -146,7 +147,7 @@ def __repr__(self):
f'{prefix}loss={self._loss_func}, \n\t'
f'{prefix}optimizer={self.optimizer})')
- def get_hist_metric(self, phase='fit', metric='loss', which='detailed'):
+ def get_hist_metric(self, phase='fit', metric='loss', which='report'):
"""Get history losses."""
assert phase in [c.FIT_PHASE, c.TEST_PHASE, c.TRAIN_PHASE, c.PREDICT_PHASE]
assert which in ['report', 'detailed']
@@ -332,7 +333,7 @@ def fit(
self.target.reset_state(self._get_input_batch_size(x))
self.reset_state()
- # training
+ # testing
res = self._get_f_loss(shared_args)(x, y)
# loss
@@ -406,7 +407,7 @@ def _get_f_loss(self, shared_args=None, jit=True) -> Callable:
if self.jit[c.LOSS_PHASE] and jit:
dyn_vars = self.target.vars()
dyn_vars.update(self._dyn_vars)
- dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+ dyn_vars.update(self.vars(level=0))
self._f_loss_compiled[shared_args_str] = bm.jit(self._f_loss_compiled[shared_args_str],
dyn_vars=dyn_vars.unique())
return self._f_loss_compiled[shared_args_str]
@@ -437,10 +438,15 @@ def _get_f_train(self, shared_args=None) -> Callable:
shared_args_str = serialize_kwargs(shared_args)
if shared_args_str not in self._f_fit_compiled:
- self._f_fit_compiled[shared_args_str] = partial(self._step_func_train, shared_args)
+ self._f_fit_compiled[shared_args_str] = partial(self._step_func_fit, shared_args)
if self.jit[c.FIT_PHASE]:
- dyn_vars = self.vars().unique()
- dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
+ dyn_vars = self.target.vars()
+ dyn_vars.update(self.optimizer.vars())
+ if isinstance(self._loss_func, bm.BrainPyObject):
+ dyn_vars.update(self._loss_func)
+ dyn_vars.update(self._dyn_vars)
+ dyn_vars.update(self.vars(level=0))
+ dyn_vars = dyn_vars.unique()
self._f_fit_compiled[shared_args_str] = bm.jit(self._f_fit_compiled[shared_args_str],
dyn_vars=dyn_vars)
return self._f_fit_compiled[shared_args_str]
@@ -448,7 +454,7 @@ def _get_f_train(self, shared_args=None) -> Callable:
def _step_func_loss(self, shared_args, inputs, targets):
raise NotImplementedError
- def _step_func_train(self, shared_args, inputs, targets):
+ def _step_func_fit(self, shared_args, inputs, targets):
raise NotImplementedError
@@ -491,9 +497,9 @@ def loss_fun(predicts, targets):
Make the monitored results as NumPy arrays.
logger: Any
A file-like object (stream). Used to output the running results. Default is the current `sys.stdout`.
- time_major: bool
- To indicate whether the first axis is the batch size (``time_major=False``) or the
- time length (``time_major=True``).
+ data_first_axis: str
+ To indicate whether the first axis is the batch size (``data_first_axis='B'``) or the
+ time length (``data_first_axis='T'``).
"""
def _step_func_loss(self, shared_args, inputs, targets):
@@ -501,14 +507,14 @@ def _step_func_loss(self, shared_args, inputs, targets):
indices = jnp.arange(num_step, dtype=bm.int_)
times = indices * self.dt + self.t0
indices = indices + self.i0
- if isinstance(self.target.mode, bm.BatchingMode) and not self.time_major:
+ if isinstance(self.target.mode, bm.BatchingMode) and self.data_first_axis == 'B':
inputs = tree_map(lambda x: bm.moveaxis(x, 0, 1), inputs, is_leaf=lambda x: isinstance(x, bm.Array))
inputs = (times, indices, inputs)
outs, mons = self._predict(xs=inputs, shared_args=shared_args)
predicts = (outs, mons) if len(mons) > 0 else outs
return self._loss_func(predicts, targets)
- def _step_func_train(self, shared_args, inputs, targets):
+ def _step_func_fit(self, shared_args, inputs, targets):
res = self._get_f_grad(shared_args)(inputs, targets)
self.optimizer.update(res[0])
return res[1:]
@@ -529,13 +535,13 @@ def _step_func_loss(self, shared_args, inputs, targets):
loss = self._loss_func(outs, targets)
return loss
- def _step_func_train(self, shared_args, inputs, targets):
+ def _step_func_fit(self, shared_args, inputs, targets):
res = self._get_f_grad(shared_args)(inputs, targets)
self.optimizer.update(res[0])
return res[1:]
def _step_func_predict(self, shared, x=None):
- assert not self.time_major, f'There is no time dimension when using the trainer {self.__class__.__name__}.'
+ assert self.data_first_axis == 'B', f'There is no time dimension when using the trainer {self.__class__.__name__}.'
# input step
self.target.clear_input()
diff --git a/brainpy/train/base.py b/brainpy/train/base.py
index 4985be181..ff017a8f2 100644
--- a/brainpy/train/base.py
+++ b/brainpy/train/base.py
@@ -5,7 +5,7 @@
import brainpy.math as bm
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import DSRunner
-from brainpy.errors import BrainPyError
+from brainpy.errors import NoLongerSupportError
from brainpy.running import constants as c
from brainpy.types import ArrayType, Output
@@ -43,10 +43,9 @@ def __init__(
super(DSTrainer, self).__init__(target=target, **kwargs)
if not isinstance(self.target.mode, bm.BatchingMode):
- raise BrainPyError(f'''
- From version 2.3.1, DSTrainer must receive a
- DynamicalSystem instance with the computing mode
- of subclass of {bm.batching_mode}.
+ raise NoLongerSupportError(f'''
+ From version 2.3.1, DSTrainer must receive a DynamicalSystem instance with
+ the computing mode of {bm.batching_mode} or {bm.training_mode}.
See https://github.com/brainpy/BrainPy/releases/tag/V2.3.1
for the solution of how to fix this.
diff --git a/brainpy/train/online.py b/brainpy/train/online.py
index d03efba11..836891a4f 100644
--- a/brainpy/train/online.py
+++ b/brainpy/train/online.py
@@ -171,7 +171,7 @@ def fit(
shared['t'] += self.t0
shared['i'] += self.i0
- if not self.time_major:
+ if self.data_first_axis == 'B':
xs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
xs,
is_leaf=lambda x: isinstance(x, bm.Array))
diff --git a/brainpy/types.py b/brainpy/types.py
index a9124e482..f957c6ee4 100644
--- a/brainpy/types.py
+++ b/brainpy/types.py
@@ -1,15 +1,17 @@
# -*- coding: utf-8 -*-
-from typing import TypeVar, Tuple
+from typing import TypeVar, Tuple, Union, Callable
import jax.numpy as jnp
import numpy as np
from brainpy.math.ndarray import Array, Variable, TrainVar
+from brainpy import connect as conn
+from brainpy import initialize as init
__all__ = [
'ArrayType', 'Parameter', 'PyTree',
- 'Shape',
+ 'Shape', 'Initializer',
'Output', 'Monitor'
]
@@ -26,4 +28,6 @@
# component
Output = TypeVar('Output') # noqa
Monitor = TypeVar('Monitor') # noqa
+Connector = Union[conn.Connector, Array, Variable, jnp.ndarray, np.ndarray]
+Initializer = Union[init.Initializer, Callable, Array, Variable, jnp.ndarray, np.ndarray]
diff --git a/changes.md b/changes.md
index 48a3fefde..9c2b1211c 100644
--- a/changes.md
+++ b/changes.md
@@ -1,94 +1,345 @@
+# Change from Version 2.3.0 to Version 2.3.1
-This release continues to add supports for brain-inspired computation.
+
+This release (under the release branch of ``brainpy=2.3.x``) continues to add support for brain-inspired computation.
+
+
+
+```python
+import brainpy as bp
+import brainpy.math as bm
+```
+
+
+
+## Backwards Incompatible Changes
+
+
+
+#### 1. Error: module 'brainpy' has no attribute 'datasets'
+
+The ``brainpy.datasets`` module is now published as an independent package, ``brainpy_datasets``.
+
+Please change your dataset access from
+
+```python
+bp.datasets.xxxxx
+```
+
+to
+
+```python
+import brainpy_datasets as bp_data
+
+bp_data.chaos.XXX
+bp_data.vision.XXX
+```
+
+For a chaotic data series,
+
+```python
+# old version
+data = bp.datasets.double_scroll_series(t_warmup + t_train + t_test, dt=dt)
+x_var = data['x']
+y_var = data['y']
+z_var = data['z']
+
+# new version
+data = bp_data.chaos.DoubleScrollEq(t_warmup + t_train + t_test, dt=dt)
+x_var = data.xs
+y_var = data.ys
+z_var = data.zs
+```
+
+For a vision dataset,
+
+```python
+# old version
+dataset = bp.datasets.FashionMNIST(root, train=True, download=True)
+
+# new version
+dataset = bp_data.vision.FashionMNIST(root, split='train', download=True)
+```
+
+
+
+#### 2. Error: DSTrainer must receive an instance with BatchingMode
+
+This error happens when using ``brainpy.OnlineTrainer``, ``brainpy.OfflineTrainer``, ``brainpy.BPTT``, or ``brainpy.BPFF``.
+
+From version 2.3.1, BrainPy explicitly considers the computing mode of each model. For trainers, the training target should be a model with ``BatchingMode`` or ``TrainingMode``.
+
+If you are training a model with ``OnlineTrainer`` or ``OfflineTrainer``,
+
+```python
+# old version
+class NGRC(bp.DynamicalSystem):
+ def __init__(self, num_in):
+ super(NGRC, self).__init__()
+ self.r = bp.layers.NVAR(num_in, delay=2, order=3)
+ self.di = bp.layers.Dense(self.r.num_out, num_in)
+
+ def update(self, sha, x):
+ di = self.di(sha, self.r(sha, x))
+ return x + di
+
+
+# new version
+bm.set_environment(mode=bm.batching_mode)
+
+class NGRC(bp.DynamicalSystem):
+ def __init__(self, num_in):
+ super(NGRC, self).__init__()
+ self.r = bp.layers.NVAR(num_in, delay=2, order=3)
+ self.di = bp.layers.Dense(self.r.num_out, num_in, mode=bm.training_mode)
+
+ def update(self, sha, x):
+ di = self.di(sha, self.r(sha, x))
+ return x + di
+```
+
+If you are training models with ``BPTrainer``, add the following line at the top of the script:
+
+```python
+bm.set_environment(mode=bm.training_mode)
+```
+
+
+
+#### 3. Error: inputs_are_batching is no longer supported.
+
+This is because a training target in ``batching`` mode already indicates that its inputs are batched.
+
+Simply removing ``inputs_are_batching`` from your call to ``.predict()`` will solve the issue, as sketched below.
+
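+For example (a minimal sketch; the runner and data names are illustrative, not part of this release):
+
+```python
+# old version
+runner.predict(xs, inputs_are_batching=True)  # now raises an error
+
+# new version: batching is implied by the model's computing mode
+runner.predict(xs)
+```
+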
## New Features
-1. ``brainpy.encoding`` module for encoding rate values into spike trains. Currently, we support
- - `brainpy.encoding.LatencyEncoder`
- - `brainpy.encoding.PoissonEncoder`
- - `brainpy.encoding.WeightedPhaseEncoder`
-2. ``brainpy.math.surrogate`` module for surrogate gradient functions. Currently, we support
+### 1. ``brainpy.math`` module upgrade
- - `brainpy.math.surrogate.arctan`
- - `brainpy.math.surrogate.erf`
- - `brainpy.math.surrogate.gaussian_grad`
- - `brainpy.math.surrogate.inv_square_grad`
- - `brainpy.math.surrogate.leaky_relu`
- - `brainpy.math.surrogate.log_tailed_relu`
- - `brainpy.math.surrogate.multi_gaussian_grad`
- - `brainpy.math.surrogate.nonzero_sign_log`
- - `brainpy.math.surrogate.one_input`
- - `brainpy.math.surrogate.piecewise_exp`
- - `brainpy.math.surrogate.piecewise_leaky_relu`
- - `brainpy.math.surrogate.piecewise_quadratic`
- - `brainpy.math.surrogate.q_pseudo_spike`
- - `brainpy.math.surrogate.relu_grad`
- - `brainpy.math.surrogate.s2nn`
- - `brainpy.math.surrogate.sigmoid`
- - `brainpy.math.surrogate.slayer_grad`
- - `brainpy.math.surrogate.soft_sign`
- - `brainpy.math.surrogate.squarewave_fourier_series`
+#### ``brainpy.math.surrogate`` module for surrogate gradient functions.
-3. ``brainpy.dyn.transfom`` module for transforming a ``DynamicalSystem`` instance to a callable ``BrainPyObject``. Specifically, we provide
+Currently, we support the following functions (a usage sketch follows the list):
- - `LoopOverTime` for unrolling a dynamical system over time.
- - `NoSharedArg` for removing the dependency of shared arguments.
+- `brainpy.math.surrogate.arctan`
+- `brainpy.math.surrogate.erf`
+- `brainpy.math.surrogate.gaussian_grad`
+- `brainpy.math.surrogate.inv_square_grad`
+- `brainpy.math.surrogate.leaky_relu`
+- `brainpy.math.surrogate.log_tailed_relu`
+- `brainpy.math.surrogate.multi_gaussian_grad`
+- `brainpy.math.surrogate.nonzero_sign_log`
+- `brainpy.math.surrogate.one_input`
+- `brainpy.math.surrogate.piecewise_exp`
+- `brainpy.math.surrogate.piecewise_leaky_relu`
+- `brainpy.math.surrogate.piecewise_quadratic`
+- `brainpy.math.surrogate.q_pseudo_spike`
+- `brainpy.math.surrogate.relu_grad`
+- `brainpy.math.surrogate.s2nn`
+- `brainpy.math.surrogate.sigmoid`
+- `brainpy.math.surrogate.slayer_grad`
+- `brainpy.math.surrogate.soft_sign`
+- `brainpy.math.surrogate.squarewave_fourier_series`
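+
+As a minimal usage sketch (the variable names and the threshold expression are illustrative; only ``bm.surrogate.sigmoid`` itself comes from the list above):
+
+```python
+v = bm.random.normal(size=10)  # membrane potentials
+v_th = 1.                      # spiking threshold
+
+# forward pass: a Heaviside-like spike; backward pass: a smooth sigmoid-shaped gradient
+spk = bm.surrogate.sigmoid(v - v_th)
+```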
-4. Change all ``brainpy.Runner`` as the subclasses of ``BrainPyObject``, which means that all ``brainpy.Runner`` can be used as a part of the high-level program or transformation.
-5. Enable the continuous running of a differential equation (ODE, SDE, FDE, DDE, etc.) with `IntegratorRunner`. For example,
- ```python
- import brainpy as bp
-
- # differential equation
- a, b, tau = 0.7, 0.8, 12.5
- dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext
- dw = lambda w, t, V: (V + a - b * w) / tau
- fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1)
-
- # differential integrator runner
- runner = bp.IntegratorRunner(fhn,
- monitors=['V', 'w'],
- inits=[1., 1.])
-
- # run 1
- Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True)
- runner.run(duration, dyn_args=dict(Iext=Iext))
- bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V')
-
- # run 2
- Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True)
- runner.run(duration, dyn_args=dict(Iext=Iext))
- bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2', show=True)
-
- ```
+#### New transformation function ``brainpy.math.to_dynsys``
-6. New transformation function ``brainpy.math.to_dynsys`` supports to transform a pure Python function into a ``DynamicalSystem``. This will be useful when running a `DynamicalSystem` with arbitrary customized inputs.
+The new transformation function ``brainpy.math.to_dynsys`` supports transforming a pure Python function into a ``DynamicalSystem``. This is useful when running a `DynamicalSystem` with arbitrary customized inputs.
- ```python
- import brainpy.math as bm
-
- hh = bp.neurons.HH(1)
-
- @bm.to_dynsys(child_objs=hh)
- def run_hh(tdi, x=None):
- if x is not None:
- hh.input += x
-
- runner = bp.DSRunner(run_hhh, monitors={'v': hh.V})
- runner.run(inputs=bm.random.uniform(3, 6, 1000))
- ```
+```python
+import brainpy.math as bm
+
+hh = bp.neurons.HH(1)
+
+@bm.to_dynsys(child_objs=hh)
+def run_hh(tdi, x=None):
+ if x is not None:
+ hh.input += x
+
+runner = bp.DSRunner(run_hh, monitors={'v': hh.V})
+runner.run(inputs=bm.random.uniform(3, 6, 1000))
+```
+
+
+
+#### Default data types
+
+Default data types `brainpy.math.int_`, `brainpy.math.float_`, and `brainpy.math.complex_` are initialized according to the default `x64` setting. These data types can be set or retrieved through the `brainpy.math.set_*` and `brainpy.math.get_*` functions.
+
+Take default integer type ``int_`` as an example,
+
+```python
+# set the default integer type
+bm.set_int(jax.numpy.int64)
+
+# get the default integer type
+a1 = bm.asarray([1], dtype=bm.int_)
+a2 = bm.asarray([1], dtype=bm.get_int()) # equivalent
+```
+
+Default data types change according to the `x64` setting of JAX. For instance,
+
+```python
+bm.enable_x64()
+assert bm.int_ == jax.numpy.int64
+bm.disable_x64()
+assert bm.int_ == jax.numpy.int32
+```
+
+``brainpy.math.float_`` and ``brainpy.math.complex_`` behave similarly to ``brainpy.math.int_``.
+
+
+
+#### Environment context manager
+
+This release introduces a new concept in BrainPy: the ``computing environment``. A computing environment is the default setting for the current computation job, including the default data types (``int_``, ``float_``, ``complex_``), the default numerical integration precision (``dt``), and the default computing mode (``mode``). All models, arrays, and computations that use the default settings are carried out under this environment.
+
+Users can set a default environment through
+
+```python
+brainpy.math.set_environment(mode, dt, x64)
+```
+
+However, one can also construct models or perform computations under a temporary environment, implemented through a context manager:
+
+```python
+# constructing a HH model with dt=0.1 and x64 precision
+with bm.environment(mode, dt=0.1, x64=True):
+ hh1 = bp.neurons.HH(1)
+
+# constructing a HH model with dt=0.05 and x32 precision
+with bm.environment(mode, dt=0.05, x64=False):
+  hh2 = bp.neurons.HH(1)
+```
+
+Usually, users construct models for either brain-inspired computing (``training mode``) or brain simulation (``nonbatching mode``). Therefore, there are shortcut context managers for setting a training or batching environment:
+
+```python
+with bm.training_environment(dt, x64):
+ pass
+
+with bm.batching_environment(dt, x64):
+ pass
+```
+
+
+
+### 2. ``brainpy.dyn`` module
+
+
+
+#### ``brainpy.dyn.transform`` module for transforming a ``DynamicalSystem`` instance to a callable ``BrainPyObject``.
-7.
+Specifically, we provide the following transformations (a usage sketch follows the list):
-8.
+- `LoopOverTime` for unrolling a dynamical system over time.
+- `NoSharedArg` for removing the dependency of shared arguments.
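+
+A minimal sketch of `LoopOverTime` (the access path ``bp.dyn.transform.LoopOverTime``, the time-major input layout, and the shapes are assumptions here):
+
+```python
+lif = bp.neurons.LIF(10)
+looper = bp.dyn.transform.LoopOverTime(lif)
+spikes = looper(bm.ones((100, 10)))  # unroll the model over 100 time steps
+```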
+
+
+
+
+
+### 3. Running supports in BrainPy
+
+
+
+#### All ``brainpy.Runner`` classes are now subclasses of ``BrainPyObject``
+
+This means that any ``brainpy.Runner`` can be used as part of a high-level program or transformation.
+
+
+
+#### Enable the continuous running of a differential equation (ODE, SDE, FDE, DDE, etc.) with `IntegratorRunner`.
+
+For example,
+
+```python
+import brainpy as bp
+
+# differential equation
+a, b, tau = 0.7, 0.8, 12.5
+dV = lambda V, t, w, Iext: V - V * V * V / 3 - w + Iext
+dw = lambda w, t, V: (V + a - b * w) / tau
+fhn = bp.odeint(bp.JointEq([dV, dw]), method='rk4', dt=0.1)
+
+# differential integrator runner
+runner = bp.IntegratorRunner(fhn, monitors=['V', 'w'], inits=[1., 1.])
+
+# run 1
+Iext, duration = bp.inputs.section_input([0., 1., 0.5], [200, 200, 200], return_length=True)
+runner.run(duration, dyn_args=dict(Iext=Iext))
+bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V')
+
+# run 2
+Iext, duration = bp.inputs.section_input([0.5], [200], return_length=True)
+runner.run(duration, dyn_args=dict(Iext=Iext))
+bp.visualize.line_plot(runner.mon.ts, runner.mon['V'], legend='V-run2', show=True)
+
+```
+
+
+
+#### Enable calling a customized function during the fitting of ``brainpy.BPTrainer``.
+
+This customized function (provided through ``fun_after_report``) is useful for saving a checkpoint during training. For instance,
+
+```python
+class CheckPoint:
+ def __init__(self, path='path/to/directory/'):
+ self.max_acc = 0.
+ self.path = path
+
+ def __call__(self, idx, metrics, phase):
+ if phase == 'test' and metrics['acc'] > self.max_acc:
+      self.max_acc = metrics['acc']
+ bp.checkpoints.save(self.path, net.state_dict(), idx)
+
+trainer = bp.BPTT(...)
+trainer.fit(..., fun_after_report=CheckPoint())
+```
+
+
+
+#### Enable the ``data_first_axis`` format when predicting or fitting with a ``brainpy.DSRunner`` or ``brainpy.DSTrainer``.
+
+Previous versions of BrainPy only support data whose first axis is the batch dimension. Now, ``brainpy.DSRunner`` and ``brainpy.DSTrainer`` also support data whose first axis is the time dimension. This can be set through ``data_first_axis='T'`` when initializing a runner or trainer.
+
+```python
+runner = bp.DSRunner(..., data_first_axis='T')
+trainer = bp.DSTrainer(..., data_first_axis='T')
+```
+
+
+
+### 4. Utility in BrainPy
+
+
+
+#### ``brainpy.encoding`` module for encoding rate values into spike trains
+
+Currently, we support the following encoders (a usage sketch follows the list):
+
+- `brainpy.encoding.LatencyEncoder`
+- `brainpy.encoding.PoissonEncoder`
+- `brainpy.encoding.WeightedPhaseEncoder`
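+
+A minimal sketch with ``PoissonEncoder`` (the call signature, including the ``num_step`` argument, is an assumption here, not a documented API):
+
+```python
+encoder = bp.encoding.PoissonEncoder()
+rates = bm.random.rand(10)             # firing probabilities in [0, 1]
+spikes = encoder(rates, num_step=100)  # hypothetical call: 100 steps of spikes
+```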
+
+
+
+#### ``brainpy.checkpoints`` module for model state serialization.
+
+This version of BrainPy supports saving a checkpoint of the model to disk. Inspired by the Flax API, we provide the following checkpoint APIs:
+
+- ``brainpy.checkpoints.save()`` for saving a checkpoint of the model.
+- ``brainpy.checkpoints.multiprocess_save()`` for saving a checkpoint of the model in a multi-process environment.
+- ``brainpy.checkpoints.load()`` for loading the last or best checkpoint from the given checkpoint path.
+- ``brainpy.checkpoints.load_latest()`` for retrieving the path of the latest checkpoint in a directory.
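+
+A minimal sketch following the ``save()`` call in the ``CheckPoint`` example above (``net`` stands for any model; the restore step through ``load_state_dict()`` is an assumption):
+
+```python
+bp.checkpoints.save('checkpoints/', net.state_dict(), 1)  # save the states at step 1
+state = bp.checkpoints.load('checkpoints/')               # load the last checkpoint
+net.load_state_dict(state)                                # restore the model states
+```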
@@ -97,7 +348,12 @@ This release continues to add supports for brain-inspired computation.
## Deprecations
-1. ``func_monitors`` is no longer supported in all ``brainpy.Runner`` subclasses. We will remove its supports since version 2.4.0. Instead, monitoring with a dict of callable functions can be set in ``monitors``. For example,
+
+### 1. Deprecations in the running supports of BrainPy
+
+#### ``func_monitors`` is no longer supported in all ``brainpy.Runner`` subclasses.
+
+Its support will be removed in version 2.4.0. Instead, monitoring with a dict of callable functions can be set in ``monitors``. For example,
```python
@@ -116,64 +372,91 @@ This release continues to add supports for brain-inspired computation.
'sp10': model.spike[10]})
```
-2. ``func_inputs`` is no longer supported in all ``brainpy.Runner`` subclasses. Instead, giving inputs with a callable function should be done with ``inputs``.
- ```python
- # old version
-
- net = EINet()
-
- def f_input(tdi):
- net.E.input += 10.
-
- runner = bp.DSRunner(net, fun_inputs=f_input, inputs=('I.input', 10.))
- ```
- ```python
- # new version
-
- def f_input(tdi):
- net.E.input += 10.
- net.I.input += 10.
- runner = bp.DSRunner(net, inputs=f_input)
- ```
+#### ``func_inputs`` is no longer supported in all ``brainpy.Runner`` subclasses.
-3. ``inputs_are_batching`` is deprecated in ``predict()``/``.run()`` of all ``brainpy.Runner`` subclasses.
+Instead, giving inputs through a callable function should be done with ``inputs``.
-4. ``args`` and ``dyn_args`` are now deprecated in ``IntegratorRunner``. Instead, users should specify ``args`` and ``dyn_args`` when using ``IntegratorRunner.run()`` function.
+```python
+# old version
- ```python
- dV = lambda V, t, w, I: V - V * V * V / 3 - w + I
- dw = lambda w, t, V, a, b: (V + a - b * w) / 12.5
- integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto')
-
- # old version
- runner = bp.IntegratorRunner(
- integral,
- monitors=['V', 'w'],
- inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)},
- args={'a': 1., 'b': 1.}, # CHANGE
- dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # CHANGE
- )
- runner.run(100.,)
-
- ```
+net = EINet()
- ```python
- # new version
- runner = bp.IntegratorRunner(
- integral,
- monitors=['V', 'w'],
- inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)},
- )
- runner.run(100.,
- args={'a': 1., 'b': 1.},
- dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)})
- ```
+def f_input(tdi):
+ net.E.input += 10.
+
+runner = bp.DSRunner(net, fun_inputs=f_input, inputs=('I.input', 10.))
+```
+
+```python
+# new version
+
+def f_input(tdi):
+ net.E.input += 10.
+ net.I.input += 10.
+runner = bp.DSRunner(net, inputs=f_input)
+```
+
+
+
+#### ``inputs_are_batching`` is deprecated.
+
+``inputs_are_batching`` is deprecated in ``predict()``/``.run()`` of all ``brainpy.Runner`` subclasses.
+
+
+
+#### ``args`` and ``dyn_args`` are now deprecated in ``IntegratorRunner``.
+
+Instead, users should specify ``args`` and ``dyn_args`` when calling the ``IntegratorRunner.run()`` function.
+
+```python
+dV = lambda V, t, w, I: V - V * V * V / 3 - w + I
+dw = lambda w, t, V, a, b: (V + a - b * w) / 12.5
+integral = bp.odeint(bp.JointEq([dV, dw]), method='exp_auto')
+
+# old version
+runner = bp.IntegratorRunner(
+ integral,
+ monitors=['V', 'w'],
+ inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)},
+ args={'a': 1., 'b': 1.}, # CHANGE
+ dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)}, # CHANGE
+)
+runner.run(100.,)
+
+```
+
+```python
+# new version
+runner = bp.IntegratorRunner(
+ integral,
+ monitors=['V', 'w'],
+ inits={'V': bm.random.rand(10), 'w': bm.random.normal(size=10)},
+)
+runner.run(100.,
+ args={'a': 1., 'b': 1.},
+ dyn_args={'I': bp.inputs.ramp_input(0, 4, 100)})
+```
+
+
+
+### 2. Deprecations in ``brainpy.math`` module
+
+#### `ditype()` and `dftype()` are deprecated.
+
+`brainpy.math.ditype()` and `brainpy.math.dftype()` are deprecated. Use `brainpy.math.int_` and `brainpy.math.float_` instead.
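+
+For example (an illustrative one-line migration):
+
+```python
+a = bm.zeros(3, dtype=bm.int_)  # instead of dtype=bm.ditype()
+```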
+
+#### ``brainpy.modes`` module is now moved into ``brainpy.math``
+
+The correspondences are listed as follows (a migration example follows the list):
+
+- ``brainpy.modes.Mode`` => ``brainpy.math.Mode``
+- ``brainpy.modes.NormalMode`` => ``brainpy.math.NonBatchingMode``
+- ``brainpy.modes.BatchingMode`` => ``brainpy.math.BatchingMode``
+- ``brainpy.modes.TrainingMode`` => ``brainpy.math.TrainingMode``
+- ``brainpy.modes.normal`` => ``brainpy.math.nonbatching_mode``
+- ``brainpy.modes.batching`` => ``brainpy.math.batching_mode``
+- ``brainpy.modes.training`` => ``brainpy.math.training_mode``
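+
+For example (illustrative only):
+
+```python
+mode = bm.TrainingMode   # the class, instead of brainpy.modes.TrainingMode
+mode = bm.training_mode  # the shared instance, instead of brainpy.modes.training
+```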
diff --git a/docs/_static/dyn_models.svg b/docs/_static/dyn_models.svg
deleted file mode 100644
index 78df73923..000000000
--- a/docs/_static/dyn_models.svg
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/docs/apis/math.rst b/docs/apis/math.rst
index 30fec794a..a68e57194 100644
--- a/docs/apis/math.rst
+++ b/docs/apis/math.rst
@@ -15,5 +15,4 @@
auto/math/surrogate
auto/math/delayvars
auto/math/activations
- math/comparison
auto/math/environment
diff --git a/docs/apis/math/comparison.rst b/docs/apis/math/comparison.rst
deleted file mode 100644
index c8b966cdb..000000000
--- a/docs/apis/math/comparison.rst
+++ /dev/null
@@ -1,9 +0,0 @@
-Comparison Table
-================
-
-Here is a list of NumPy APIs and its corresponding BrainPy implementations.
-
-``-`` in BrainPy column denotes that implementation is not provided yet.
-We welcome contributions for these functions.
-
-.. include:: ../auto/math/comparison_table.rst.inc
diff --git a/docs/auto_generater.py b/docs/auto_generater.py
index 79583ffb7..d4f774faa 100644
--- a/docs/auto_generater.py
+++ b/docs/auto_generater.py
@@ -207,22 +207,6 @@ def generate_train_docs(path='apis/auto/train/'):
template=False)
-def generate_base_docs(path='apis/auto/'):
- if not os.path.exists(path):
- os.makedirs(path)
-
- module_and_name = [
- ('base', 'Base Class'),
- ('function', 'Function Wrapper'),
- ('collector', 'Collectors'),
- ('io', 'Exporting and Loading'),
- ('naming', 'Naming Tools')]
- write_submodules(module_name='brainpy.base',
- filename=os.path.join(path, 'base.rst'),
- header='``brainpy.base`` module',
- submodule_names=[k[0] for k in module_and_name],
- section_names=[k[1] for k in module_and_name])
-
def generate_connect_docs(path='apis/auto/'):
if not os.path.exists(path):
@@ -452,35 +436,21 @@ def generate_math_docs(path='apis/auto/math/'):
if not os.path.exists(path):
os.makedirs(path)
- buf = []
- buf += _section(header='Multi-dimensional Array',
- numpy_mod='numpy',
- brainpy_mod='brainpy.math',
- jax_mod='jax.numpy',
- klass='ndarray', )
- buf += _section(header='Array Operations',
- numpy_mod='numpy',
- brainpy_mod='brainpy.math',
- jax_mod='jax.numpy',
- is_jax=True)
- buf += _section(header='Linear Algebra',
- numpy_mod='numpy.linalg',
- brainpy_mod='brainpy.math.linalg',
- jax_mod='jax.numpy.linalg', )
- buf += _section(header='Discrete Fourier Transform',
- numpy_mod='numpy.fft',
- brainpy_mod='brainpy.math.fft',
- jax_mod='jax.numpy.fft', )
- buf += _section(header='Random Sampling',
- numpy_mod='numpy.random',
- brainpy_mod='brainpy.math.random',
- jax_mod='jax.random',)
- codes = '\n'.join(buf)
+ module_and_name = [
+ ('base_object', 'Basic BrainPy Object'),
+ ('collector', 'Basic Variable Collector'),
+ ('base_transform', 'Basic Transformation Object'),
+ ('autograd', 'Automatic Differentiation'),
+ ('controls', 'Control Flows'),
+ ('jit', 'JIT Compilation'),
+ ('function', 'Function to Object'),
+ ]
+ write_submodules(module_name='brainpy.math.object_transform',
+ filename=os.path.join(path, 'object_transform.rst'),
+ header='Object-oriented Transformation',
+ submodule_names=[k[0] for k in module_and_name],
+ section_names=[k[1] for k in module_and_name])
- if not os.path.exists(path):
- os.makedirs(path)
- with open(os.path.join(path, 'comparison_table.rst.inc'), 'w') as f:
- f.write(codes)
module_and_name = [
('pre_syn_post', '``pre-syn-post`` Transformations',),
@@ -493,19 +463,6 @@ def generate_math_docs(path='apis/auto/math/'):
submodule_names=[k[0] for k in module_and_name],
section_names=[k[1] for k in module_and_name])
- module_and_name = [
- ('autograd', 'Automatic Differentiation'),
- ('controls', 'Control Flows'),
- # ('parallels', 'Parallel Compilation'),
- ('jit', 'JIT Compilation'),
- ('function', 'Function to Object'),
- ]
- write_submodules(module_name='brainpy.math.object_transform',
- filename=os.path.join(path, 'object_transform.rst'),
- header='Object-oriented Transformation',
- submodule_names=[k[0] for k in module_and_name],
- section_names=[k[1] for k in module_and_name])
-
write_module(module_name='brainpy.math.surrogate',
filename=os.path.join(path, 'surrogate.rst'),
header='Surrogate Gradient Functions')
@@ -532,7 +489,6 @@ def generate_measure_docs(path='apis/auto/'):
header='``brainpy.measure`` module')
-
def generate_optimizers_docs(path='apis/auto/'):
if not os.path.exists(path):
os.makedirs(path)
diff --git a/docs/conf.py b/docs/conf.py
index 7fc6029ca..739038c2e 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -20,7 +20,6 @@
import brainpy
from docs import auto_generater
-auto_generater.generate_base_docs()
auto_generater.generate_analysis_docs()
auto_generater.generate_train_docs()
auto_generater.generate_algorithm_docs()
diff --git a/docs/core_concept/brainpy_dynamical_system.ipynb b/docs/core_concept/brainpy_dynamical_system.ipynb
new file mode 100644
index 000000000..8a78a9713
--- /dev/null
+++ b/docs/core_concept/brainpy_dynamical_system.ipynb
@@ -0,0 +1,508 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# Concept 2: Dynamical System"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "@[Chaoming Wang](https://github.com/chaoming0625)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "BrainPy supports models in brain simulation and brain-inspired computing.\n",
+ "\n",
+ "All these supports are based on one common concept: **Dynamical System** via ``brainpy.DynamicalSystem``.\n",
+ "\n",
+ "Therefore, it is essential to understand:\n",
+ "1. what is ``brainpy.DynamicalSystem``?\n",
+ "2. how to define ``brainpy.DynamicalSystem``?\n",
+ "3. how to run ``brainpy.DynamicalSystem``?"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "'2.3.1'"
+ },
+ "execution_count": 25,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import brainpy as bp\n",
+ "import brainpy.math as bm\n",
+ "\n",
+ "bm.set_platform('cpu')\n",
+ "\n",
+ "bp.__version__"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## What is ``DynamicalSystem``?"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "All models used in brain simulation and brain-inspired computing is ``DynamicalSystem``.\n",
+ "\n",
+ "A ``DynamicalSystem`` defines the updating rule of the model at single time step.\n",
+ "\n",
+ "1. For models with state, ``DynamicalSystem`` defines the state transition from $t$ to $t+dt$, i.e., $S(t+dt) = F\\left(S(t), x, t, dt\\right)$, where $S$ is the state, $x$ is input, $t$ is the time, and $dt$ is the time step. This is the case for recurrent neural networks (like GRU, LSTM), neuron models (like HH, LIF), or synapse models which are widely used in brain simulation.\n",
+ "\n",
+ "2. However, for models in deep learning, like convolution and fully-connected linear layers, ``DynamicalSystem`` defines the input-to-output mapping, i.e., $y=F\\left(x, t\\right)$."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "![](imgs/dynamical_system.png)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## How to define ``DynamicalSystem``?"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Keep in mind that the usage of ``DynamicalSystem`` has several constraints in BrainPy."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### 1. ``.update()`` function\n",
+ "\n",
+ "First, *all ``DynamicalSystem`` should implement ``.update()`` function*, which receives two arguments:\n",
+ "\n",
+ "- `s` (or named as others): A dict, to indicate shared arguments across all nodes/layers in the network, like\n",
+ " - the current time ``t``, or\n",
+ " - the current running index ``i``, or\n",
+ " - the current time step ``dt``, or\n",
+ " - the current phase of training or testing ``fit=True/False``.\n",
+ "- `x` (or named as others): The individual input for this node/layer.\n",
+ "\n",
+ "We call `s` as shared arguments because they are shared and same for all nodes/layers at current time step. On the contrary, different nodes/layers have different input `x`."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Example: LIF neuron model for brain simulation**\n",
+ "\n",
+ "Here we illustrate the first constraint of ``DynamicalSystem`` using the Leaky Integrate-and-Fire (LIF) model.\n",
+ "\n",
+ "The LIF model is firstly proposed in brain simulation for modeling neuron dynamics. Its equation is given by\n",
+ "\n",
+ "$$\n",
+ "\\begin{aligned}\n",
+ "\\tau_m \\frac{dV}{dt} = - (V(t) - V_{rest}) + I(t) \\\\\n",
+ "\\text{if} \\, V(t) \\gt V_{th}, V(t) =V_{rest}\n",
+ "\\end{aligned}\n",
+ "$$\n",
+ "\n",
+ "For the details of the model, users should refer to Wikipedia or other resource.\n"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "outputs": [],
+ "source": [
+ "class LIF_for_BrainSimulation(bp.DynamicalSystem):\n",
+ " def __init__(self, size, V_rest=0., V_th=1., tau=5., mode=None):\n",
+ " super().__init__(mode=mode)\n",
+ "\n",
+ " # this model only supports non-batching mode\n",
+ " bp.check.is_subclass(self.mode, bm.NonBatchingMode)\n",
+ "\n",
+ " # parameters\n",
+ " self.size = size\n",
+ " self.V_rest = V_rest\n",
+ " self.V_th = V_th\n",
+ " self.tau = tau\n",
+ "\n",
+ " # variables\n",
+ " self.V = bm.Variable(bm.ones(size) * V_rest)\n",
+ " self.spike = bm.Variable(bm.zeros(size, dtype=bool))\n",
+ "\n",
+ " # integrate differential equation with exponential euler method\n",
+ " self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')\n",
+ "\n",
+ " def update(self, s, x):\n",
+ " # define how the model states update\n",
+ " # according to the external input\n",
+ " t, dt = s.get('t'), s.get('dt', bm.dt)\n",
+ " V = self.integral(self.V, t, x, dt=dt)\n",
+ " spike = V >= self.V_th\n",
+ " self.V.value = bm.where(spike, self.V_rest, V)\n",
+ " self.spike.value = spike\n",
+ " return spike"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### 2. Computing mode\n",
+ "\n",
+ "Second, **explicitly consider which computing mode your ``DynamicalSystem`` supports**.\n",
+ "\n",
+ "Brain simulation usually constructs models without batching dimension (we refer to it as *non-batching mode*, as seen in above LIF model), while brain-inspired computation trains models with a batch of data (*batching mode* or *training mode*).\n",
+ "\n",
+ "So, to write a model applicable to the abroad applications in brain simulation and brain-inspired computing, you need to consider which mode your model supports, one of them, or both of them."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "**Example: LIF neuron model for both brain simulation and brain-inspired computing**\n",
+ "\n",
+ "When considering the computing mode, we can program a general LIF model for brain simulation and brain-inspired computing.\n",
+ "\n",
+ "To overcome the non-differential property of the spike in the LIF model for brain simulation, for the code\n",
+ "\n",
+ "```python\n",
+ "spike = V >= self.V_th\n",
+ "```\n",
+ "\n",
+ "LIF models used in brain-inspired computing calculate the spiking state using the surrogate gradient function, i.e., replacing the backward gradient with a smooth function, like\n",
+ "\n",
+ "$$\n",
+ "g'(x) = \\frac{1}{(\\alpha * |x| + 1.) ^ 2}\n",
+ "$$"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
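+ {
+ "cell_type": "markdown",
+ "source": [
+ "Conceptually, such a surrogate spike function keeps the Heaviside step in the forward pass and substitutes the smooth gradient above in the backward pass. BrainPy ships ready-made versions in ``brainpy.math.surrogate`` (used in the next cell); the following is only a minimal sketch in plain JAX, with ``alpha`` as a hypothetical sharpness parameter:\n",
+ "\n",
+ "```python\n",
+ "import jax\n",
+ "import jax.numpy as jnp\n",
+ "\n",
+ "alpha = 4.0  # hypothetical sharpness parameter\n",
+ "\n",
+ "@jax.custom_vjp\n",
+ "def surrogate_spike(x):\n",
+ "    return (x >= 0.).astype(jnp.float32)  # forward: Heaviside step\n",
+ "\n",
+ "def _fwd(x):\n",
+ "    return surrogate_spike(x), x  # save x for the backward pass\n",
+ "\n",
+ "def _bwd(x, g):\n",
+ "    # backward: g'(x) = 1 / (alpha * |x| + 1)^2\n",
+ "    return (g / (alpha * jnp.abs(x) + 1.) ** 2,)\n",
+ "\n",
+ "surrogate_spike.defvjp(_fwd, _bwd)\n",
+ "```"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },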
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "outputs": [],
+ "source": [
+ "class LIF(bp.DynamicalSystem):\n",
+ " def __init__(self, size, f_surrogate=None, V_rest=0., V_th=1., tau=5.,mode=None):\n",
+ " super().__init__(mode=mode)\n",
+ " bp.check.is_subclass(self.mode, [bm.NonBatchingMode, bm.BatchingMode, bm.TrainingMode])\n",
+ "\n",
+ " # Parameters\n",
+ " self.size = size\n",
+ " self.num = bp.tools.size2num(size)\n",
+ " self.V_rest = V_rest\n",
+ " self.V_th = V_th\n",
+ " self.tau = tau\n",
+ " if f_surrogate is None:\n",
+ " f_surrogate = bm.surrogate.inv_square_grad\n",
+ " self.f_surrogate = f_surrogate\n",
+ "\n",
+ " # integrate differential equation with exponential euler method\n",
+ " self.integral = bp.odeint(f=lambda V, t, I: (-V + V_rest + I)/tau, method='exp_auto')\n",
+ "\n",
+ " # Initialize a Variable:\n",
+ " # - if non-batching mode, batch axis of V is None\n",
+ " # - if batching mode, batch axis of V is 0\n",
+ " self.V = bp.init.variable_(bm.zeros, self.size, self.mode)\n",
+ " self.V[:] = self.V_rest\n",
+ " self.spike = bp.init.variable_(bm.zeros, self.size, self.mode)\n",
+ "\n",
+ " def reset_state(self, batch_size=None):\n",
+ " self.V.value = bp.init.variable_(bm.ones, self.size, batch_size) * self.V_rest\n",
+ " self.spike.value = bp.init.variable_(bm.zeros, self.size, batch_size)\n",
+ "\n",
+ " def update(self, s, x):\n",
+ " t, dt = s.get('t'), s.get('dt', bm.dt)\n",
+ " V = self.integral(self.V, t, x, dt=dt)\n",
+ " # replace non-differential heaviside function\n",
+ " # with a surrogate gradient function\n",
+ " spike = self.f_surrogate(V - self.V_th)\n",
+ " # reset membrane potential\n",
+ " self.V.value = (1. - spike) * V + spike * self.V_rest\n",
+ " self.spike.value = spike\n",
+ " return spike"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Model composition\n",
+ "\n",
+ "The ``LIF`` model we have defined above can be recursively composed to construct networks in brain simulation and brain-inspired computing."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "outputs": [],
+ "source": [
+ "class EINet(bp.DynamicalSystem):\n",
+ " def __init__(self, num_exc, num_inh):\n",
+ " super().__init__()\n",
+ " self.E = LIF(num_exc, V_rest=-55, V_th=-50., tau=20.)\n",
+ " self.I = LIF(num_inh, V_rest=-55, V_th=-50., tau=20.)\n",
+ " self.E2E = bp.synapses.Exponential(self.E, self.E, bp.conn.FixedProb(0.02),\n",
+ " g_max=1.62, tau=5., output=None)\n",
+ " self.E2I = bp.synapses.Exponential(self.E, self.I, bp.conn.FixedProb(0.02),\n",
+ " g_max=1.62, tau=5., output=None)\n",
+ " self.I2E = bp.synapses.Exponential(self.I, self.E, bp.conn.FixedProb(0.02),\n",
+ " g_max=-9.0, tau=10., output=None)\n",
+ " self.I2I = bp.synapses.Exponential(self.I, self.I, bp.conn.FixedProb(0.02),\n",
+ " g_max=-9.0, tau=10., output=None)\n",
+ "\n",
+ " def update(self, s, x):\n",
+ " # x is the background input\n",
+ " e2e = self.E2E(s)\n",
+ " e2i = self.E2I(s)\n",
+ " i2e = self.I2E(s)\n",
+ " i2i = self.I2I(s)\n",
+ " self.E(s, e2e + i2e + x)\n",
+ " self.I(s, e2i + i2i + x)\n",
+ "\n",
+ "with bm.environment(mode=bm.nonbatching_mode):\n",
+ " net1 = EINet(3200, 800)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Here the ``EINet`` defines an E/I balanced network which is a classical network model in brain simulation. The following ``AINet`` utilizes the LIF model to construct a model for AI training."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "outputs": [],
+ "source": [
+ "# This network can be used in AI applications\n",
+ "\n",
+ "class AINet(bp.DynamicalSystem):\n",
+ " def __init__(self, sizes):\n",
+ " super().__init__()\n",
+ " self.neu1 = LIF(sizes[0])\n",
+ " self.syn1 = bp.layers.Dense(sizes[0], sizes[1])\n",
+ " self.neu2 = LIF(sizes[1])\n",
+ " self.syn2 = bp.layers.Dense(sizes[1], sizes[2])\n",
+ " self.neu3 = LIF(sizes[2])\n",
+ "\n",
+ " def update(self, s, x):\n",
+ " x = self.neu1(s, x)\n",
+ " x = self.syn1(s, x)\n",
+ " x = self.neu2(s, x)\n",
+ " x = self.syn2(s, x)\n",
+ " x = self.neu3(s, x)\n",
+ " return x\n",
+ "\n",
+ "with bm.environment(mode=bm.training_mode):\n",
+ " net2 = AINet([100, 50, 10])"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## How to run ``DynamicalSystem``?"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "As we have stated above that ``DynamicalSystem`` only defines the updating rule at single time step, to run a ``DynamicalSystem`` instance over time, we need a for loop mechanism.\n",
+ "\n",
+ "![](./imgs/dynamical_system_and_dsrunner.png)"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### 1. ``brainpy.math.for_loop``\n",
+ "\n",
+ "``for_loop`` is a structural control flow API which runs a function with the looping over the inputs.\n",
+ "\n",
+ "Suppose we have 200 time steps with the step size of 0.1, we can run the model with:"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": "(200, 10, 10)"
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "with bm.environment(dt=0.1):\n",
+ " # construct a set of shared argument with the given time steps\n",
+ " shared = bm.shared_args_over_time(num_step=200)\n",
+ " # construct the inputs with shape of (time, batch, feature)\n",
+ " currents = bm.random.rand(200, 10, 100)\n",
+ "\n",
+ " # run the model\n",
+ " net2.reset_state(batch_size=10)\n",
+ " out = bm.for_loop(net2, (shared, currents))\n",
+ "\n",
+ "out.shape"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
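+ {
+ "cell_type": "markdown",
+ "source": [
+ "The returned ``out`` stacks the network outputs over time: its three axes are (time, batch, feature), i.e., 200 time steps, a batch size of 10, and the 10 output neurons of ``AINet``."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },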
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### 2. ``brainpy.DSRunner`` and ``brainpy.DSTrainer``\n",
+ "\n",
+ "Another way to run the model in BrainPy is using the structural running object ``DSRunner`` and ``DSTrainer``. They provide more flexible way to monitoring the variables in a ``DynamicalSystem``.\n"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "outputs": [
+ {
+ "data": {
+ "text/plain": " 0%| | 0/1000 [00:00, ?it/s]",
+ "application/vnd.jupyter.widget-view+json": {
+ "version_major": 2,
+ "version_minor": 0,
+ "model_id": "877e3a333fef466e8189a9db2a99fa2a"
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "text/plain": "