Skip to content

Commit

Permalink
Merge pull request #316 from chaoming0625/master
Browse files Browse the repository at this point in the history
Ready for publish
  • Loading branch information
chaoming0625 authored Dec 29, 2022
2 parents dbdc5c6 + 7cf0ad3 commit 2e2d6f4
Show file tree
Hide file tree
Showing 9 changed files with 300 additions and 64 deletions.
3 changes: 1 addition & 2 deletions brainpy/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,9 @@
from . import surrogate
from .surrogate.compt import *

# JAX transformations for Variable and class objects
# Variable and Objects for object-oriented JAX transformations
from .object_transform import *


# environment settings
from .modes import *
from .environment import *
Expand Down
9 changes: 5 additions & 4 deletions brainpy/math/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@


import warnings
from typing import Optional, Tuple
from typing import Optional, Tuple as TupleType

import numpy as np
from jax import numpy as jnp
from jax.tree_util import register_pytree_node


from brainpy.errors import MathError

__all__ = [
Expand Down Expand Up @@ -997,7 +998,7 @@ def __init__(
f'but the batch axis is set to be {batch_axis}.')

@property
def shape_nb(self) -> Tuple[int, ...]:
def shape_nb(self) -> TupleType[int, ...]:
"""Shape without batch axis."""
shape = list(self.value.shape)
if self.batch_axis is not None:
Expand Down Expand Up @@ -1562,7 +1563,6 @@ class BatchVariable(Variable):
pass



class VariableView(Variable):
"""A view of a Variable instance.
Expand Down Expand Up @@ -1742,7 +1742,7 @@ def _jaxarray_unflatten(aux_data, flat_contents):


register_pytree_node(Array,
lambda t: ((t.value,), (t._transform_context, )),
lambda t: ((t.value,), (t._transform_context,)),
_jaxarray_unflatten)

register_pytree_node(Variable,
Expand All @@ -1756,3 +1756,4 @@ def _jaxarray_unflatten(aux_data, flat_contents):
register_pytree_node(Parameter,
lambda t: ((t.value,), None),
lambda aux_data, flat_contents: Parameter(*flat_contents))

6 changes: 6 additions & 0 deletions brainpy/math/object_transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,15 @@
+ controls.__all__
+ jit.__all__
+ function.__all__
+ base_object.__all__
+ base_transform.__all__
+ collector.__all__
)

from .autograd import *
from .controls import *
from .jit import *
from .function import *
from .base_object import *
from .base_transform import *
from .collector import *
131 changes: 119 additions & 12 deletions brainpy/math/object_transform/base_object.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,39 @@
# -*- coding: utf-8 -*-

import os
import logging
import warnings
from collections import namedtuple
from typing import Any, Tuple, Callable, Sequence, Dict, Union

import jax
import numpy as np
from jax._src.tree_util import _registry
from jax.tree_util import register_pytree_node
from jax.tree_util import register_pytree_node_class
from jax.util import safe_zip

from brainpy import errors
from .collector import Collector, ArrayCollector
from ..ndarray import Variable, VariableView, TrainVar
from ..ndarray import (Array,
Variable,
VariableView,
TrainVar)

StateLoadResult = namedtuple('StateLoadResult', ['missing_keys', 'unexpected_keys'])


__all__ = [
'check_name_uniqueness',
'get_unique_name',
'clear_name_cache',
# naming
'check_name_uniqueness', 'get_unique_name', 'clear_name_cache',

# objects
'BrainPyObject', 'Base', 'FunAsObject',

# variables
'numerical_seq', 'object_seq',
'numerical_dict', 'object_dict',
]

logger = logging.getLogger('brainpy.brainpy_object')

_name2id = dict()
_typed_names = {}
Expand Down Expand Up @@ -59,7 +72,7 @@ def clear_name_cache(ignore_warn=False):
_name2id.clear()
_typed_names.clear()
if not ignore_warn:
logger.warning(f'All named models and their ids are cleared.')
warnings.warn(f'All named models and their ids are cleared.', UserWarning)


class BrainPyObject(object):
Expand All @@ -78,6 +91,11 @@ class BrainPyObject(object):
_excluded_vars = ()

def __init__(self, name=None):
super().__init__()
cls = self.__class__
if cls not in _registry:
register_pytree_node_class(cls)

# check whether the object has a unique name.
self._name = None
self._name = self.unique_name(name=name)
Expand All @@ -91,15 +109,17 @@ def __init__(self, name=None):
# which cannot be accessed by self.xxx
self.implicit_nodes = Collector()

def __setattr__(self, key, value) -> None:
"""Overwrite __setattr__ method for non-changeable Variable setting.
def __setattr__(self, key: str, value: Any) -> None:
"""Overwrite `__setattr__` method for change Variable values.
.. versionadded:: 2.3.1
Parameters
----------
key: str
The attribute.
value: Any
The value.
"""
if key in self.__dict__:
val = self.__dict__[key]
Expand All @@ -109,19 +129,24 @@ def __setattr__(self, key, value) -> None:
super().__setattr__(key, value)

def tree_flatten(self):
"""
"""Flattens the object as a PyTree.
The flattening order is determined by attributes added order.
.. versionadded:: 2.3.1
Returns
-------
res: tuple
A tuple of dynamical values and static values.
"""
dts = (BrainPyObject,) + tuple(dynamical_types)
dynamic_names = []
dynamic_values = []
static_names = []
static_values = []
for k, v in self.__dict__.items():
if isinstance(v, (ArrayCollector, BrainPyObject, Variable)):
if isinstance(v, dts):
dynamic_names.append(k)
dynamic_values.append(v)
else:
Expand Down Expand Up @@ -531,3 +556,85 @@ def __repr__(self) -> str:
node_string = ", \n".join(nodes)
return (f'{name}(nodes=[{node_string}],\n' +
" " * (len(name) + 1) + f'num_of_vars={len(self.implicit_vars)})')


class numerical_seq(list):
"""A list to represent a dynamically changed numerical
sequence in which its element can be changed during JIT compilation.
.. note::
The element must be numerical, like ``bool``, ``int``, ``float``,
``jax.Array``, ``numpy.ndarray``, ``brainpy.math.Array``.
"""
def append(self, element):
if not isinstance(element, (bool, int, float, jax.Array, Array, np.ndarray)):
raise TypeError(f'Each element should be a numerical value.')

def extend(self, iterable) -> None:
for element in iterable:
self.append(element)


register_pytree_node(numerical_seq,
lambda x: (tuple(x), ()),
lambda _, values: numerical_seq(values))


class object_seq(list):
"""A list to represent a sequence of :py:class:`~.BrainPyObject`.
.. note::
The element must be :py:class:`~.BrainPyObject`.
"""
def append(self, element):
if not isinstance(element, BrainPyObject):
raise TypeError(f'Only support {BrainPyObject.__name__}')

def extend(self, iterable) -> None:
for element in iterable:
self.append(element)


register_pytree_node(object_seq,
lambda x: (tuple(x), ()),
lambda _, values: object_seq(values))


class numerical_dict(dict):
"""A dict to represent a dynamically changed numerical
dictionary in which its element can be changed during JIT compilation.
.. note::
Each key must be a string, and each value must be numerical, including
``bool``, ``int``, ``float``, ``jax.Array``, ``numpy.ndarray``,
``brainpy.math.Array``.
"""
def update(self, *args, **kwargs) -> 'numerical_dict':
super().update(*args, **kwargs)
return self


register_pytree_node(numerical_dict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: numerical_dict(safe_zip(keys, values)))


class object_dict(dict):
"""A dict to represent a dictionary of :py:class:`~.BrainPyObject`.
.. note::
Each key must be a string, and each value must be :py:class:`~.BrainPyObject`.
"""
def update(self, *args, **kwargs) -> 'object_dict':
super().update(*args, **kwargs)
return self


register_pytree_node(object_dict,
lambda x: (tuple(x.values()), tuple(x.keys())),
lambda keys, values: object_dict(safe_zip(keys, values)))

dynamical_types = [Variable,
numerical_seq, numerical_dict,
object_seq, object_dict]

Loading

0 comments on commit 2e2d6f4

Please sign in to comment.