Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support tracing Variable during computation and compilation by using tracing_variable() function #472

Merged
merged 11 commits into from
Sep 9, 2023
4 changes: 2 additions & 2 deletions brainpy/_src/checkpoints/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def save_as_h5(filename: str, variables: dict):
raise ValueError(f'Cannot save variables as a HDF5 file. We only support file with '
f'postfix of ".hdf5" and ".h5". But we got {filename}')

import h5py
import h5py # noqa

# check variables
check_dict_data(variables, name='variables')
Expand Down Expand Up @@ -184,7 +184,7 @@ def load_by_h5(filename: str, target, verbose: bool = False):
f'postfix of ".hdf5" and ".h5". But we got {filename}')

# read data
import h5py
import h5py # noqa
load_vars = dict()
with h5py.File(filename, "r") as f:
for key in f.keys():
Expand Down
48 changes: 24 additions & 24 deletions brainpy/_src/checkpoints/tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ def __init__(self):
print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
# def test_h5(self):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
#
# bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
#
# def test_h5_postfix(self):
# with self.assertRaises(ValueError):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
# with self.assertRaises(ValueError):
# bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
Expand Down Expand Up @@ -120,18 +120,18 @@ def __init__(self):
print(self.net.vars().keys())
print(self.net.vars().unique().keys())

def test_h5(self):
bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)

bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)

def test_h5_postfix(self):
with self.assertRaises(ValueError):
bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
with self.assertRaises(ValueError):
bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
# def test_h5(self):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
#
# bp.checkpoints.io.save_as_h5('io_test_tmp.hdf5', self.net.vars())
# bp.checkpoints.io.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
#
# def test_h5_postfix(self):
# with self.assertRaises(ValueError):
# bp.checkpoints.io.save_as_h5('io_test_tmp.h52', self.net.vars())
# with self.assertRaises(ValueError):
# bp.checkpoints.io.load_by_h5('io_test_tmp.h52', self.net, verbose=True)

def test_npz(self):
bp.checkpoints.io.save_as_npz('io_test_tmp.npz', self.net.vars())
Expand Down
7 changes: 5 additions & 2 deletions brainpy/_src/connect/random_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,11 @@ def build_csr(self):
return selected_post_ids.astype(get_idx_type()), selected_pre_inptr.astype(get_idx_type())

def build_mat(self):
pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
if self.pre_ratio < 1.:
pre_state = self._jaxrand.uniform(size=(self.pre_num, 1)) < self.pre_ratio
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
else:
mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob)
mat = bm.asarray(mat)
if not self.include_self:
bm.fill_diagonal(mat, False)
Expand Down
4 changes: 2 additions & 2 deletions brainpy/_src/dyn/neurons/hh.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def update(self, x=None):
return super().update(x)


class HHLTC(NeuDyn):
class HHLTC(HHTypedNeuron):
r"""Hodgkin–Huxley neuron model with liquid time constant.

**Model Descriptions**
Expand Down Expand Up @@ -758,7 +758,7 @@ def update(self, x=None):
return super().update(x)


class WangBuzsakiHHLTC(NeuDyn):
class WangBuzsakiHHLTC(HHTypedNeuron):
r"""Wang-Buzsaki model [9]_, an implementation of a modified Hodgkin-Huxley model with liquid time constant.

Each model is described by a single compartment and obeys the current balance equation:
Expand Down
55 changes: 47 additions & 8 deletions brainpy/_src/dyn/projections/aligns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Optional, Callable, Union

import jax

from brainpy import math as bm, check
from brainpy._src.delay import Delay, DelayAccess, delay_identifier, init_delay_by_return
from brainpy._src.dynsys import DynamicalSystem, Projection
Expand Down Expand Up @@ -127,6 +125,7 @@ def __init__(

# references
self.refs = dict(post=post, out=out) # invisible to ``self.nodes()``
self.refs['comm'] = comm # unify the access

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -218,6 +217,7 @@ def __init__(
self.refs = dict(post=post) # invisible to ``self.nodes()``
self.refs['syn'] = post.get_bef_update(self._post_repr).syn
self.refs['out'] = post.get_bef_update(self._post_repr).out
self.refs['comm'] = comm # unify the access

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -342,6 +342,9 @@ def __init__(
self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
self.refs['syn'] = post.get_bef_update(self._post_repr).syn # invisible to ``self.node()``
self.refs['out'] = post.get_bef_update(self._post_repr).out # invisible to ``self.node()``
# unify the access
self.refs['comm'] = comm
self.refs['delay'] = pre.get_aft_update(delay_identifier)

def update(self):
x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
Expand Down Expand Up @@ -422,9 +425,13 @@ def __init__(
post.add_bef_update(self.name, _AlignPost(syn, out))

# reference
self.refs = dict(post=post) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['post'] = post
self.refs['syn'] = post.get_bef_update(self.name).syn
self.refs['out'] = post.get_bef_update(self.name).out
# unify the access
self.refs['comm'] = comm

def update(self, x):
current = self.comm(x)
Expand Down Expand Up @@ -538,8 +545,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
# unify the access
self.refs['delay'] = pre.get_aft_update(delay_identifier)
self.refs['comm'] = comm
self.refs['syn'] = syn

def update(self):
x = self.refs['pre'].get_aft_update(delay_identifier).at(self.name)
Expand Down Expand Up @@ -655,8 +669,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out, delay=delay_cls) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = delay_cls
self.refs['syn'] = pre.get_aft_update(self._syn_id).syn
# unify the access
self.refs['comm'] = comm

def update(self, x=None):
if x is None:
Expand Down Expand Up @@ -778,9 +799,14 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post) # invisible to `self.nodes()`
self.refs = dict()
# invisible to `self.nodes()`
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['syn'] = delay_cls.get_bef_update(self._syn_id).syn
self.refs['out'] = out
# unify the access
self.refs['comm'] = comm

def update(self):
x = _get_return(self.refs['syn'].return_info())
Expand Down Expand Up @@ -890,9 +916,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = delay_cls
self.refs['syn'] = syn
# unify the access
self.refs['comm'] = comm

def update(self, x=None):
if x is None:
Expand Down Expand Up @@ -1006,8 +1038,15 @@ def __init__(
post.add_inp_fun(out_name, out)

# references
self.refs = dict(pre=pre, post=post, out=out) # invisible to ``self.nodes()``
self.refs = dict()
# invisible to ``self.nodes()``
self.refs['pre'] = pre
self.refs['post'] = post
self.refs['out'] = out
self.refs['delay'] = pre.get_aft_update(delay_identifier)
# unify the access
self.refs['syn'] = syn
self.refs['comm'] = comm

def update(self):
spk = self.refs['delay'].at(self.name)
Expand Down
68 changes: 59 additions & 9 deletions brainpy/_src/math/object_transform/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from brainpy._src.math.object_transform.naming import (get_unique_name,
check_name_uniqueness)
from brainpy._src.math.object_transform.variables import (Variable, VariableView, TrainVar,
VarList, VarDict)
VarList, VarDict, var_stack_list)

StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])

Expand Down Expand Up @@ -102,17 +102,68 @@ def __init__(self, name=None):
def setattr(self, key: str, value: Any) -> None:
super().__setattr__(key, value)

def tracing_variable(self, name: str, value: Union[jax.Array, Array]) -> Variable:
"""Initialize and get the variable which can be traced during computation.

Although this function is designed to initialize tracing variables during computation or compilation,
it can also be used for initialization of variables before or after computation and compilation.

- If ``name`` has been used in this object, a ``KeyError`` will be raised.
- If the variable has not been instantiated, the given ``value`` will be used to
instantiate a :py:class:`~.Variable`.
- If the variable has been created, the further call of this function will
refresh the value of the variable with the given ``value``.

Here is the usage example::

class Example(bm.BrainPyObject):
def fun(self):
# this line will create a Variable instance
self.tracing_variable('a', bm.zeros(10))

# calling this function again will assign a different value
# to the created Variable instance
self.tracing_variable('a', bm.random.random(10))

Args:
name: str. The variable name.
value: Array. The data of the in-trace variable. It can also be the instance of
:py:class:`~.Variable`, so that users can control the property of the created
variable instance. If an ``Array`` is provided, then it will be instantiated
as a :py:class:`~.Variable`.

Returns:
The instance of :py:class:`~.Variable`.
"""
# the variable has been created
if hasattr(self, name):
var = getattr(self, name)
if isinstance(var, Variable):
var.value = value
return var

# create the variable
if not isinstance(value, Variable):
value = Variable(value)
value._ready_to_trace = True
if len(var_stack_list) > 0 and isinstance(value._value, jax.core.Tracer):
with jax.ensure_compile_time_eval():
value._value = jax.numpy.zeros_like(value._value)
self.setattr(name, value)
# if not isinstance(var, Variable):
# raise KeyError(f'"{name}" has been used in this class. Please assign '
# f'another name for the initialization of variables '
# f'tracing during computation and compilation.')
return value

def __setattr__(self, key: str, value: Any) -> None:
"""Overwrite `__setattr__` method for changing :py:class:`~.Variable` values.

.. versionadded:: 2.3.1

Parameters
----------
key: str
The attribute.
value: Any
The value.
Args:
key: str. The attribute.
value: Any. The value.
"""
if key in self.__dict__:
val = self.__dict__[key]
Expand Down Expand Up @@ -252,7 +303,7 @@ def vars(self,
continue
v = getattr(node, k)
if isinstance(v, Variable) and not isinstance(v, exclude_types):
gather[f'{node_path}.{k}' if node_path else k] = v
gather[f'{node_path}.{k}' if node_path else k] = v
elif isinstance(v, VarList):
for i, vv in enumerate(v):
if not isinstance(vv, exclude_types):
Expand Down Expand Up @@ -702,4 +753,3 @@ def __setitem__(self, key, value) -> 'VarDict':


node_dict = NodeDict

51 changes: 51 additions & 0 deletions brainpy/_src/math/object_transform/tests/test_variable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import brainpy.math as bm
import unittest


class TestVar(unittest.TestCase):
def test1(self):
class A(bm.BrainPyObject):
def __init__(self):
super().__init__()
self.a = bm.Variable(1)
self.f1 = bm.jit(self.f)
self.f2 = bm.jit(self.ff)
self.f3 = bm.jit(self.fff)

def f(self):
b = self.tracing_variable('b', bm.ones(1, ))
self.a += (b * 2)
return self.a.value

def ff(self):
self.b += 1.

def fff(self):
self.f()
self.ff()
self.b *= self.a
return self.b.value

print()
f_jit = bm.jit(A().f)
f_jit()
self.assertTrue(len(f_jit._dyn_vars) == 2)

print()
a = A()
self.assertTrue(bm.all(a.f1() == 2.))
self.assertTrue(len(a.f1._dyn_vars) == 2)
print(a.f2())
self.assertTrue(len(a.f2._dyn_vars) == 1)

print()
a = A()
print()
self.assertTrue(bm.allclose(a.f3(), 4.))
self.assertTrue(len(a.f3._dyn_vars) == 2)

bm.clear_buffer_memory()




Loading
Loading