diff --git a/brainpy/_src/checkpoints/io.py b/brainpy/_src/checkpoints/io.py
index bf254bf0e..4e712c5ca 100644
--- a/brainpy/_src/checkpoints/io.py
+++ b/brainpy/_src/checkpoints/io.py
@@ -151,7 +151,7 @@ def save_as_h5(filename: str, variables: dict):
     raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
                      f'postfix of ".hdf5" and ".h5". But we got {filename}')
 
-  import h5py
+  import h5py  # noqa
 
   # check variables
   check_dict_data(variables, name='variables')
@@ -184,7 +184,7 @@ def load_by_h5(filename: str, target, verbose: bool = False):
                      f'postfix of ".hdf5" and ".h5". But we got {filename}')
 
   # read data
-  import h5py
+  import h5py  # noqa
   load_vars = dict()
   with h5py.File(filename, "r") as f:
     for key in f.keys():
diff --git a/brainpy/_src/checkpoints/tests/test_io.py b/brainpy/_src/checkpoints/tests/test_io.py
index f8ed80210..36c8f374b 100644
--- a/brainpy/_src/checkpoints/tests/test_io.py
+++ b/brainpy/_src/checkpoints/tests/test_io.py
@@ -40,18 +40,18 @@ def __init__(self):
     print(self.net.vars().keys())
     print(self.net.vars().unique().keys())
 
-  def test_h5(self):
-    bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
-    bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
-
-    bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
-    bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
-
-  def test_h5_postfix(self):
-    with self.assertRaises(ValueError):
-      bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
-    with self.assertRaises(ValueError):
-      bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
+  # def test_h5(self):
+  #   bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
+  #   bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
+  #
+  #   bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
+  #   bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
+  #
+  # def test_h5_postfix(self):
+  #   with self.assertRaises(ValueError):
+  #     bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
+  #   with self.assertRaises(ValueError):
+  #     bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
 
   def test_npz(self):
     bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
@@ -120,18 +120,18 @@ def __init__(self):
     print(self.net.vars().keys())
     print(self.net.vars().unique().keys())
 
-  def test_h5(self):
-    bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
-    bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
-
-    bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
-    bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
-
-  def test_h5_postfix(self):
-    with self.assertRaises(ValueError):
-      bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
-    with self.assertRaises(ValueError):
-      bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
+  # def test_h5(self):
+  #   bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
+  #   bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
+  #
+  #   bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
+  #   bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
+  #
+  # def test_h5_postfix(self):
+  #   with self.assertRaises(ValueError):
+  #     bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
+  #   with self.assertRaises(ValueError):
+  #     bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
 
   def test_npz(self):
     bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py
index ee98ea135..1f5b1db6d 100644
--- a/brainpy/_src/connect/random_conn.py
+++ b/brainpy/_src/connect/random_conn.py
@@ -128,8 +128,11 @@ def build_csr(self):
     return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type())
 
   def build_mat(self):
-    pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
-    mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
+    if self.pre_ratio < 1.:
+      pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
+      mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
+    else:
+      mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob)
     mat = bm.asarray(mat)
     if not self.include_self:
       bm.fill_diagonal(mat, False)
diff --git a/brainpy/_src/dyn/neurons/hh.py b/brainpy/_src/dyn/neurons/hh.py
index afb4ab262..2069f4e65 100644
--- a/brainpy/_src/dyn/neurons/hh.py
+++ b/brainpy/_src/dyn/neurons/hh.py
@@ -180,7 +180,7 @@ def update(self, x=None):
     return super().update(x)
 
 
-class HHLTC(NeuDyn):
+class HHLTC(HHTypedNeuron):
   r"""Hodgkin–Huxley neuron model with liquid time constant.
 
   **Model Descriptions**
@@ -758,7 +758,7 @@ def update(self, x=None):
     return super().update(x)
 
 
-class WangBuzsakiHHLTC(NeuDyn):
+class WangBuzsakiHHLTC(HHTypedNeuron):
   r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant.
 
   Each model is described by a single compartment and obeys the current balance equation:
diff --git a/brainpy/_src/dyn/projections/aligns.py b/brainpy/_src/dyn/projections/aligns.py
index d63033eb7..2dfa2dd14 100644
--- a/brainpy/_src/dyn/projections/aligns.py
+++ b/brainpy/_src/dyn/projections/aligns.py
@@ -1,7 +1,5 @@
 from typing import Optional, Callable, Union
 
-import jax
-
 from brainpy import math as bm, check
 from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return
 from brainpy._src.dynsys import DynamicalSystem, Projection
@@ -127,6 +125,7 @@ def __init__(
 
     # references
     self.refs = dict(post=post, out=out)  # invisible to ``self.nodes()``
+    self.refs['comm'] = comm  # unify the access
 
   def update(self, x):
     current = self.comm(x)
@@ -218,6 +217,7 @@ def __init__(
     self.refs = dict(post=post)  # invisible to ``self.nodes()``
     self.refs['syn'] = post.get_bef_update(self._post_repr).syn
     self.refs['out'] = post.get_bef_update(self._post_repr).out
+    self.refs['comm'] = comm  # unify the access
 
   def update(self, x):
     current = self.comm(x)
@@ -342,6 +342,9 @@ def __init__(
     self.refs = dict(pre=pre, post=post)  # invisible to ``self.nodes()``
     self.refs['syn'] = post.get_bef_update(self._post_repr).syn  # invisible to ``self.node()``
     self.refs['out'] = post.get_bef_update(self._post_repr).out  # invisible to ``self.node()``
+    # unify the access
+    self.refs['comm'] = comm
+    self.refs['delay'] = pre.get_aft_update(delay_identifier)
 
   def update(self):
     x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
@@ -422,9 +425,13 @@ def __init__(
     post.add_bef_update(self.name, _AlignPost(syn, out))
 
     # reference
-    self.refs = dict(post=post)  # invisible to ``self.nodes()``
+    self.refs = dict()
+    # invisible to ``self.nodes()``
+    self.refs['post'] = post
     self.refs['syn'] = post.get_bef_update(self.name).syn
     self.refs['out'] = post.get_bef_update(self.name).out
+    # unify the access
+    self.refs['comm'] = comm
 
   def update(self, x):
     current = self.comm(x)
@@ -538,8 +545,15 @@ def __init__(
       post.add_inp_fun(out_name, out)
 
     # references
-    self.refs = dict(pre=pre, post=post)  # invisible to ``self.nodes()``
+    self.refs = dict()
+    # invisible to ``self.nodes()``
+    self.refs['pre'] = pre
+    self.refs['post'] = post
     self.refs['out'] = out
+    # unify the access
+    self.refs['delay'] = pre.get_aft_update(delay_identifier)
+    self.refs['comm'] = comm
+    self.refs['syn'] = syn
 
   def update(self):
     x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
@@ -655,8 +669,15 @@ def __init__(
       post.add_inp_fun(out_name, out)
 
     # references
-    self.refs = dict(pre=pre, post=post, out=out, delay=delay_cls)  # invisible to ``self.nodes()``
+    self.refs = dict()
+    # invisible to ``self.nodes()``
+    self.refs['pre'] = pre
+    self.refs['post'] = post
+    self.refs['out'] = out
+    self.refs['delay'] = delay_cls
     self.refs['syn'] = pre.get_aft_update(self._syn_id).syn
+    # unify the access
+    self.refs['comm'] = comm
 
   def update(self, x=None):
     if x is None:
@@ -778,9 +799,14 @@ def __init__(
       post.add_inp_fun(out_name, out)
 
     # references
-    self.refs = dict(pre=pre, post=post)  # invisible to `self.nodes()`
+    self.refs = dict()
+    # invisible to `self.nodes()`
+    self.refs['pre'] = pre
+    self.refs['post'] = post
     self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn
     self.refs['out'] = out
+    # unify the access
+    self.refs['comm'] = comm
 
   def update(self):
     x = _get_return(self.refs['syn'].return_info())
@@ -890,9 +916,15 @@ def __init__(
       post.add_inp_fun(out_name, out)
 
     # references
-    self.refs = dict(pre=pre, post=post, out=out)  # invisible to ``self.nodes()``
+    self.refs = dict()
+    # invisible to ``self.nodes()``
+    self.refs['pre'] = pre
+    self.refs['post'] = post
+    self.refs['out'] = out
     self.refs['delay'] = delay_cls
     self.refs['syn'] = syn
+    # unify the access
+    self.refs['comm'] = comm
 
   def update(self, x=None):
     if x is None:
@@ -1006,8 +1038,15 @@ def __init__(
       post.add_inp_fun(out_name, out)
 
     # references
-    self.refs = dict(pre=pre, post=post, out=out)  # invisible to ``self.nodes()``
+    self.refs = dict()
+    # invisible to ``self.nodes()``
+    self.refs['pre'] = pre
+    self.refs['post'] = post
+    self.refs['out'] = out
     self.refs['delay'] = pre.get_aft_update(delay_identifier)
+    # unify the access
+    self.refs['syn'] = syn
+    self.refs['comm'] = comm
 
   def update(self):
     spk = self.refs['delay'].at(self.name)
diff --git a/brainpy/_src/math/modes.py b/brainpy/_src/math/modes.py
index 5e72ff09c..674035e18 100644
--- a/brainpy/_src/math/modes.py
+++ b/brainpy/_src/math/modes.py
@@ -61,6 +61,10 @@ class NonBatchingMode(Mode):
   """
   pass
 
+  @property
+  def batch_size(self):
+    return tuple()
+
 
 class BatchingMode(Mode):
   """Batching mode.
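
A minimal sketch of what the unified ``refs`` table in the projection classes above and the new ``NonBatchingMode.batch_size`` property give downstream code. The constructor arguments mirror the example script later in this patch; the concrete sizes are illustrative, not part of the change:

  import brainpy as bp
  import brainpy.math as bm

  post = bp.dyn.LifRef(4000)
  proj = bp.dyn.ProjAlignPostMg1(
    comm=bp.dnn.EventJitFPHomoLinear(3200, post.num, prob=0.02, weight=0.6),
    syn=bp.dyn.Expon.desc(size=post.num, tau=5.),
    out=bp.dyn.COBA.desc(E=0.),
    post=post,
  )
  # every projection variant now exposes its communication layer under the same key
  assert proj.refs['comm'] is proj.comm
  # non-batching mode now reports an empty batch shape
  assert bm.NonBatchingMode().batch_size == ()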
diff --git a/brainpy/_src/math/object_transform/base.py b/brainpy/_src/math/object_transform/base.py
index 851e23776..daa8a55bb 100644
--- a/brainpy/_src/math/object_transform/base.py
+++ b/brainpy/_src/math/object_transform/base.py
@@ -20,8 +20,12 @@
 from brainpy._src.math.object_transform.naming import (get_unique_name,
                                                         check_name_uniqueness)
 from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
-                                                           VarList, VarDict)
+                                                           VarList, VarDict, var_stack_list)
+from brainpy._src.math.modes import Mode
+from brainpy._src.math.sharding import BATCH_AXIS
 
+variable_ = None
 StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])
 
 __all__ = [
@@ -102,17 +106,91 @@ def __init__(self, name=None):
 
   def setattr(self, key: str, value: Any) -> None:
     super().__setattr__(key, value)
 
+  def tracing_variable(
+      self,
+      name: str,
+      init: Union[Callable, Array, jax.Array],
+      shape: Union[int, Sequence[int]],
+      batch_or_mode: Union[int, bool, Mode] = None,
+      batch_axis: int = 0,
+      axis_names: Optional[Sequence[str]] = None,
+      batch_axis_name: Optional[str] = BATCH_AXIS,
+  ) -> Variable:
+    """Initialize the variable which can be traced during computations and transformations.
+
+    Although this function is designed to initialize tracing variables during computation or compilation,
+    it can also be used for the initialization of variables before computation and compilation.
+
+    - If the variable has not been instantiated, a :py:class:`~.Variable` will be instantiated.
+    - If the variable has been created, further calls of this function will return the created variable.
+
+    Here is the usage example::
+
+       class Example(bm.BrainPyObject):
+         def fun(self):
+           # The first time `.fun()` is called, this line creates a Variable instance.
+           # If users repeatedly call `.fun()`, this line will not initialize the variable again.
+           # Instead, it will return the variable that has been created.
+           self.tracing_variable('a', bm.zeros, (10,))
+
+           # The created variable can be accessed with self.xxx
+           self.a.value = bm.ones(10)
+
+           # Calling this function again will not reinitialize the
+           # variable. Instead, it will return the variable
+           # that has been created.
+           a = self.tracing_variable('a', bm.zeros, (10,))
+
+    Args:
+      name: str. The variable name.
+      init: callable, Array. The data to be initialized as a ``Variable``.
+      shape: int, sequence of int. The shape of the variable.
+      batch_or_mode: int, bool, Mode. This is used to specify the batch size of this variable.
+        If it is a boolean or an instance of ``Mode``, the batch size will be 1.
+        If it is None, the variable has no batch axis.
+      batch_axis: int. The batch axis, if batch size is given.
+      axis_names: sequence of str. The name for each axis. These names should match the given ``shape``.
+      batch_axis_name: str. The name for the batch axis. The name will be used
+        if ``batch_or_mode`` is given. Default is ``brainpy.math.sharding.BATCH_AXIS``.
+
+    Returns:
+      The instance of :py:class:`~.Variable`.
+    """
+    # the variable has been created
+    if hasattr(self, name):
+      var = getattr(self, name)
+      if isinstance(var, Variable):
+        return var
+        # if var.shape != value.shape:
+        #   raise ValueError(
+        #     f'"{name}" has been used in this class with the shape of {var.shape} (!= {value.shape}). '
+        #     f'Please assign another name for the initialization of variables '
+        #     f'tracing during computation and compilation.'
+        #   )
+        # if var.dtype != value.dtype:
+        #   raise ValueError(
+        #     f'"{name}" has been used in this class with the dtype of {var.dtype} (!= {value.dtype}). '
+        #     f'Please assign another name for the initialization of variables '
+        #     f'tracing during computation and compilation.'
+        #   )
+
+    global variable_
+    if variable_ is None:
+      from brainpy.initialize import variable_
+    with jax.ensure_compile_time_eval():
+      value = variable_(init, shape, batch_or_mode, batch_axis, axis_names, batch_axis_name)
+      value._ready_to_trace = True
+    self.setattr(name, value)
+    return value
+
   def __setattr__(self, key: str, value: Any) -> None:
     """Overwrite `__setattr__` method for changing :py:class:`~.Variable` values.
 
     .. versionadded:: 2.3.1
 
-    Parameters
-    ----------
-    key: str
-      The attribute.
-    value: Any
-      The value.
+    Args:
+      key: str. The attribute.
+      value: Any. The value.
     """
     if key in self.__dict__:
       val = self.__dict__[key]
@@ -252,7 +330,7 @@ def vars(self,
             continue
           v = getattr(node, k)
           if isinstance(v, Variable) and not isinstance(v, exclude_types):
-            gather[f'{node_path}.{k}' if node_path else k] = v
+            gather[f'{node_path}.{k}' if node_path else k] = v
           elif isinstance(v, VarList):
             for i, vv in enumerate(v):
               if not isinstance(vv, exclude_types):
@@ -702,4 +780,3 @@ def __setitem__(self, key, value) -> 'VarDict':
 
 
 node_dict = NodeDict
-
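
A minimal usage sketch of the ``tracing_variable`` API added above (the class and variable names here are hypothetical); the new test file below exercises the same behaviour under ``bm.jit``:

  import brainpy.math as bm

  class Counter(bm.BrainPyObject):
    def step(self):
      # created lazily on the first call; later calls return the same Variable
      c = self.tracing_variable('count', bm.zeros, (1,))
      c += 1.
      return c.value

  step = bm.jit(Counter().step)
  print(step())  # [1.] -- the traced variable is registered with the JIT transform
  print(step())  # [2.] -- its updated value is carried over to the next call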
diff --git a/brainpy/_src/math/object_transform/tests/test_variable.py b/brainpy/_src/math/object_transform/tests/test_variable.py
new file mode 100644
index 000000000..ddf7c8d22
--- /dev/null
+++ b/brainpy/_src/math/object_transform/tests/test_variable.py
@@ -0,0 +1,51 @@
+import brainpy.math as bm
+import unittest
+
+
+class TestVar(unittest.TestCase):
+  def test1(self):
+    class A(bm.BrainPyObject):
+      def __init__(self):
+        super().__init__()
+        self.a = bm.Variable(1)
+        self.f1 = bm.jit(self.f)
+        self.f2 = bm.jit(self.ff)
+        self.f3 = bm.jit(self.fff)
+
+      def f(self):
+        b = self.tracing_variable('b', bm.ones, (1,))
+        self.a += (b * 2)
+        return self.a.value
+
+      def ff(self):
+        self.b += 1.
+
+      def fff(self):
+        self.f()
+        self.ff()
+        self.b *= self.a
+        return self.b.value
+
+    print()
+    f_jit = bm.jit(A().f)
+    f_jit()
+    self.assertTrue(len(f_jit._dyn_vars) == 2)
+
+    print()
+    a = A()
+    self.assertTrue(bm.all(a.f1() == 2.))
+    self.assertTrue(len(a.f1._dyn_vars) == 2)
+    print(a.f2())
+    self.assertTrue(len(a.f2._dyn_vars) == 1)
+
+    print()
+    a = A()
+    print()
+    self.assertTrue(bm.allclose(a.f3(), 4.))
+    self.assertTrue(len(a.f3._dyn_vars) == 2)
+
+    bm.clear_buffer_memory()
+
+
+
diff --git a/brainpy/_src/math/object_transform/variables.py b/brainpy/_src/math/object_transform/variables.py
index f526a6680..06020f4cc 100644
--- a/brainpy/_src/math/object_transform/variables.py
+++ b/brainpy/_src/math/object_transform/variables.py
@@ -39,6 +39,12 @@ def add(self, var: 'Variable'):
     if id_ not in self:
       self[id_] = var
       self._values[id_] = var._value
+      # v = var._value
+      # if isinstance(v, Tracer):
+      #   with jax.ensure_compile_time_eval():
+      #     v = jnp.zeros_like(v)
+      #   var._value = v
+      # self._values[id_] = v
 
   def collect_values(self):
     """Collect the value of each variable once again."""
@@ -71,7 +77,7 @@ def dict_data(self) -> dict:
     """Get all data in the collected variables with a python dict structure."""
     new_dict = dict()
     for id_, elem in tuple(self.items()):
-      new_dict[id_] = elem.value if isinstance(elem, Array) else elem
+      new_dict[id_] = elem.value
     return new_dict
 
   def list_data(self) -> list:
@@ -108,7 +114,6 @@ def __add__(self, other: dict):
     new_dict._values.update(other._values)
     return new_dict
 
-
 var_stack_list: List[VariableStack] = []
 transform_stack: List[Callable] = []
 
@@ -163,14 +168,11 @@ class Variable(Array):
   Note that when initializing a `Variable` by the data shape,
   all values in this `Variable` will be initialized as zeros.
 
-  Parameters
-  ----------
-  value_or_size: Shape, Array, int
-    The value or the size of the value.
-  dtype:
-    The type of the data.
-  batch_axis: optional, int
-    The batch axis.
+  Args:
+    value_or_size: Shape, Array, int. The value or the size of the value.
+    dtype: Any. The type of the data.
+    batch_axis: optional, int. The batch axis.
+    axis_names: sequence of str. The name for each axis.
""" __slots__ = ('_value', '_batch_axis', '_ready_to_trace', 'axis_names') @@ -191,7 +193,7 @@ def __init__( else: value = value_or_size - super(Variable, self).__init__(value, dtype=dtype) + super().__init__(value, dtype=dtype) # check batch axis if isinstance(value, Variable): @@ -276,7 +278,6 @@ def value(self, v): v = v self._value = v - def _append_to_stack(self): if self._ready_to_trace: for stack in var_stack_list: @@ -319,7 +320,7 @@ def __init__( axis_names: Optional[Sequence[str]] = None, _ready_to_trace: bool = True ): - super(TrainVar, self).__init__( + super().__init__( value_or_size, dtype=dtype, batch_axis=batch_axis, @@ -342,7 +343,7 @@ def __init__( axis_names: Optional[Sequence[str]] = None, _ready_to_trace: bool = True ): - super(Parameter, self).__init__( + super().__init__( value_or_size, dtype=dtype, batch_axis=batch_axis, @@ -390,7 +391,7 @@ def __init__( self.index = jax.tree_util.tree_map(_as_jax_array_, index, is_leaf=lambda a: isinstance(a, Array)) if not isinstance(value, Variable): raise ValueError('Must be instance of Variable.') - super(VariableView, self).__init__(value.value, batch_axis=value.batch_axis, _ready_to_trace=False) + super().__init__(value.value, batch_axis=value.batch_axis, _ready_to_trace=False) self._value = value def __repr__(self) -> str: diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index ddd4753a9..eb04c5d2e 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -1253,6 +1253,19 @@ def split_key(): return DEFAULT.split_key() +def split_keys(n): + """Create multiple seeds from the current seed. This is used + internally by `pmap` and `vmap` to ensure that random numbers + are different in parallel threads. + + Parameters + ---------- + n : int + The number of seeds to generate. 
+ """ + return DEFAULT.split_keys(n) + + def clone_rng(seed_or_key=None, clone: bool = True) -> RandomState: if seed_or_key is None: return DEFAULT.clone() if clone else DEFAULT diff --git a/brainpy/_src/math/sparse/_csr_mv.py b/brainpy/_src/math/sparse/_csr_mv.py index e43965d4d..9a37f0902 100644 --- a/brainpy/_src/math/sparse/_csr_mv.py +++ b/brainpy/_src/math/sparse/_csr_mv.py @@ -81,6 +81,9 @@ def csrmv( indptr = as_jax(indptr) vector = as_jax(vector) + if vector.dtype == jnp.bool_: + vector = as_jax(vector, dtype=data.dtype) + if method == 'cusparse': if jax.default_backend() == 'gpu': if data.shape[0] == 1: diff --git a/brainpy/_src/running/pathos_multiprocessing.py b/brainpy/_src/running/pathos_multiprocessing.py index b58b1691e..1573a541c 100644 --- a/brainpy/_src/running/pathos_multiprocessing.py +++ b/brainpy/_src/running/pathos_multiprocessing.py @@ -18,8 +18,8 @@ from brainpy.errors import PackageMissingError try: - from pathos.helpers import cpu_count - from pathos.multiprocessing import ProcessPool + from pathos.helpers import cpu_count # noqa + from pathos.multiprocessing import ProcessPool # noqa except ModuleNotFoundError: cpu_count = None ProcessPool = None diff --git a/brainpy/math/random.py b/brainpy/math/random.py index ed3fbeea4..dde1f4832 100644 --- a/brainpy/math/random.py +++ b/brainpy/math/random.py @@ -7,6 +7,7 @@ seed as seed, split_key as split_key, + split_keys as split_keys, default_rng as default_rng, # numpy compatibility diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py index a0f10de09..45cf81953 100644 --- a/examples/dynamics_simulation/COBA_parallel.py +++ b/examples/dynamics_simulation/COBA_parallel.py @@ -2,10 +2,23 @@ import brainpy as bp import brainpy.math as bm +from jax.experimental.maps import xmap + # bm.set_host_device_count(4) +class ExpJIT(bp.Projection): + def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.): + super().__init__() + self.proj = bp.dyn.ProjAlignPostMg1( + comm=bp.dnn.EventJitFPHomoLinear(pre_num, post.num, prob=prob, weight=g_max), + syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]), + out=bp.dyn.COBA.desc(E=E), + post=post + ) + + class EINet1(bp.DynSysGroup): def __init__(self): super().__init__() @@ -13,18 +26,8 @@ def __init__(self): V_initializer=bp.init.Normal(-55., 2.), sharding=[bm.sharding.NEU_AXIS]) self.delay = bp.VarDelay(self.N.spike, entries={'I': None}) - self.E = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6), - syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=0.), - post=self.N - ) - self.I = bp.dyn.ProjAlignPostMg1( - comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7), - syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]), - out=bp.dyn.COBA.desc(E=-80.), - post=self.N - ) + self.E = ExpJIT(3200, self.N, 0.02, 0.6) + self.I = ExpJIT(800, self.N, 0.02, 6.7, E=-80., tau=10.) 
diff --git a/examples/dynamics_simulation/COBA_parallel.py b/examples/dynamics_simulation/COBA_parallel.py
index a0f10de09..45cf81953 100644
--- a/examples/dynamics_simulation/COBA_parallel.py
+++ b/examples/dynamics_simulation/COBA_parallel.py
@@ -2,10 +2,23 @@
 import brainpy as bp
 import brainpy.math as bm
 
+from jax.experimental.maps import xmap
+
 # bm.set_host_device_count(4)
 
 
+class ExpJIT(bp.Projection):
+  def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.):
+    super().__init__()
+    self.proj = bp.dyn.ProjAlignPostMg1(
+      comm=bp.dnn.EventJitFPHomoLinear(pre_num, post.num, prob=prob, weight=g_max),
+      syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]),
+      out=bp.dyn.COBA.desc(E=E),
+      post=post
+    )
+
+
 class EINet1(bp.DynSysGroup):
   def __init__(self):
     super().__init__()
@@ -13,18 +26,8 @@ def __init__(self):
                              V_initializer=bp.init.Normal(-55., 2.),
                              sharding=[bm.sharding.NEU_AXIS])
     self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
-    self.E = bp.dyn.ProjAlignPostMg1(
-      comm=bp.dnn.EventJitFPHomoLinear(3200, 4000, prob=0.02, weight=0.6),
-      syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]),
-      out=bp.dyn.COBA.desc(E=0.),
-      post=self.N
-    )
-    self.I = bp.dyn.ProjAlignPostMg1(
-      comm=bp.dnn.EventJitFPHomoLinear(800, 4000, prob=0.02, weight=6.7),
-      syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]),
-      out=bp.dyn.COBA.desc(E=-80.),
-      post=self.N
-    )
+    self.E = ExpJIT(3200, self.N, 0.02, 0.6)
+    self.I = ExpJIT(800, self.N, 0.02, 6.7, E=-80., tau=10.)
 
   def update(self, input):
     spk = self.delay.at('I')
@@ -34,6 +37,18 @@ def update(self, input):
     return self.N.spike.value
 
 
+class ExpMasked(bp.Projection):
+  def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.):
+    super().__init__()
+    self.proj = bp.dyn.ProjAlignPostMg1(
+      comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max,
+                               sharding=[None, bm.sharding.NEU_AXIS]),
+      syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]),
+      out=bp.dyn.COBA.desc(E=E),
+      post=post
+    )
+
+
 class EINet2(bp.DynSysGroup):
   def __init__(self):
     super().__init__()
@@ -41,21 +56,79 @@ def __init__(self):
                              V_initializer=bp.init.Normal(-55., 2.),
                              sharding=[bm.sharding.NEU_AXIS])
     self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
-    self.E = bp.dyn.ProjAlignPostMg1(
-      comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=3200, post=4000), weight=0.6,
-                               sharding=[None, bm.sharding.NEU_AXIS]),
-      syn=bp.dyn.Expon.desc(size=4000, tau=5., sharding=[bm.sharding.NEU_AXIS]),
-      out=bp.dyn.COBA.desc(E=0.),
-      post=self.N
+    self.E = ExpMasked(3200, self.N, 0.02, 0.6)
+    self.I = ExpMasked(800, self.N, 0.02, 6.7, E=-80., tau=10.)
+
+  def update(self, input):
+    spk = self.delay.at('I')
+    self.E(spk[:3200])
+    self.I(spk[3200:])
+    self.delay(self.N(input))
+    return self.N.spike.value
+
+
+class PCSR(bp.dnn.Layer):
+  def __init__(self, conn, weight, num_shard, transpose=True):
+    super().__init__()
+
+    self.conn = conn
+    self.transpose = transpose
+    self.num_shard = num_shard
+
+    # connection
+    self.indices = []
+    self.inptr = []
+    for _ in range(num_shard):
+      indices, inptr = self.conn.require('csr')
+      self.indices.append(indices)
+      self.inptr.append(inptr)
+    self.indices = bm.asarray(self.indices)
+    self.inptr = bm.asarray(self.inptr)
+
+    # weight
+    weight = bp.init.parameter(weight, (self.indices.size,))
+    if isinstance(self.mode, bm.TrainingMode):
+      weight = bm.TrainVar(weight)
+    self.weight = weight
+
+  def update(self, v):
+    # ax1 = None if bm.size(self.weight) > 1 else (None, bm.sharding.NEU_AXIS)
+    mapped = xmap(
+      self._f,
+      in_axes=((bm.sharding.NEU_AXIS, None), (bm.sharding.NEU_AXIS, None), (None, )),
+      out_axes=(bm.sharding.NEU_AXIS, None),
+      # axis_resources={bm.sharding.NEU_AXIS: bm.sharding.NEU_AXIS},
     )
-    self.I = bp.dyn.ProjAlignPostMg1(
-      comm=bp.dnn.MaskedLinear(bp.conn.FixedProb(0.02, pre=800, post=4000), weight=0.6,
-                               sharding=[None, bm.sharding.NEU_AXIS]),
-      syn=bp.dyn.Expon.desc(size=4000, tau=10., sharding=[bm.sharding.NEU_AXIS]),
-      out=bp.dyn.COBA.desc(E=-80.),
-      post=self.N
+    r = mapped(self.indices, self.inptr, v)
+    return r.flatten()
+
+  def _f(self, indices, indptr, x):
+    return bm.event.csrmv(self.weight, indices, indptr, x,
+                          shape=(self.conn.pre_num, self.conn.post_num // self.num_shard),
+                          transpose=self.transpose)
+
+
+class ExpMasked2(bp.Projection):
+  def __init__(self, pre_num, post, prob, g_max, tau=5., E=0.):
+    super().__init__()
+    self.proj = bp.dyn.ProjAlignPostMg1(
+      comm=PCSR(bp.conn.FixedProb(prob, pre=pre_num, post=post.num), weight=g_max, num_shard=4),
+      syn=bp.dyn.Expon.desc(size=post.num, tau=tau, sharding=[bm.sharding.NEU_AXIS]),
+      out=bp.dyn.COBA.desc(E=E),
+      post=post
     )
+
+
+class EINet3(bp.DynSysGroup):
+  def __init__(self):
+    super().__init__()
+    self.N = bp.dyn.LifRefLTC(4000, V_rest=-60., V_th=-50., V_reset=-60., tau=20., tau_ref=5.,
+                              V_initializer=bp.init.Normal(-55., 2.),
+                              sharding=[bm.sharding.NEU_AXIS])
+    self.delay = bp.VarDelay(self.N.spike, entries={'I': None})
+    self.E = ExpMasked2(3200, self.N, 0.02, 0.6)
+    self.I = ExpMasked2(800, self.N, 0.02, 6.7, E=-80., tau=10.)
+
   def update(self, input):
     spk = self.delay.at('I')
     self.E(spk[:3200])
@@ -64,14 +137,44 @@
     return self.N.spike.value
 
 
-@bm.jit
-def run(indexes):
-  return bm.for_loop(lambda i: model.step_run(i, 20.), indexes)
+def try_ei_net1():
+  @bm.jit
+  def run(indexes):
+    return bm.for_loop(lambda i: model.step_run(i, 20.), indexes)
+
+  with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
+    model = EINet1()
+    indices = bm.arange(1000)
+    spks = run(indices)
+  bp.visualize.raster_plot(indices, spks, show=True)
+
+
+def try_ei_net2():
+  @bm.jit
+  def run(indexes):
+    return bm.for_loop(lambda i: model.step_run(i, 20.), indexes)
+
+  with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
+    model = EINet2()
+    indices = bm.arange(1000)
+    spks = run(indices)
+  bp.visualize.raster_plot(indices, spks, show=True)
+
+
+
+def try_ei_net3():
+  @bm.jit
+  def run(indexes):
+    return bm.for_loop(lambda i: model.step_run(i, 20.), indexes)
 
+  with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
+    model = EINet3()
+    indices = bm.arange(1000)
+    spks = run(indices)
+  bp.visualize.raster_plot(indices, spks, show=True)
 
-with bm.sharding.device_mesh(jax.devices(), [bm.sharding.NEU_AXIS]):
-  model = EINet2()
-  indices = bm.arange(1000)
-  spks = run(indices)
-bp.visualize.raster_plot(indices, spks, show=True)
+if __name__ == '__main__':
+  # try_ei_net1()
+  # try_ei_net2()
+  try_ei_net3()
diff --git a/requirements-dev.txt b/requirements-dev.txt
index d8e87ac5f..126f0bd27 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,8 +6,7 @@ jax>=0.4.1
 jaxlib>=0.4.1
 scipy>=1.1.0
 brainpylib
-h5py
-pathos
+numba
 
 # test requirements
 pytest
diff --git a/requirements-doc.txt b/requirements-doc.txt
index dc67a4b04..d41a8cf41 100644
--- a/requirements-doc.txt
+++ b/requirements-doc.txt
@@ -6,6 +6,7 @@ jax>=0.4.1
 matplotlib>=3.4
 jaxlib>=0.4.1
 scipy>=1.1.0
+numba
 
 # document requirements
 pandoc
diff --git a/requirements.txt b/requirements.txt
index d8343cde7..74db0a68a 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,3 +2,4 @@ numpy
 jax>=0.4.1
 tqdm
 msgpack
+numba
\ No newline at end of file