Skip to content

Commit

Permalink
Update code and doc (#18)
Browse files Browse the repository at this point in the history
* update code

* update documentation
  • Loading branch information
chaoming0625 authored Oct 27, 2024
1 parent d654c73 commit 6d63d95
Show file tree
Hide file tree
Showing 13 changed files with 1,186 additions and 38 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@

# A ``State``-based Transformation System for Brain Dynamics Programming
# A ``State``-based Transformation System for Program Compilation and Augmentation



<p align="center">
<img alt="Header image of brainstate." src="https://github.com/brainpy/brainstate/blob/main/docs/_static/brainstate.png" width=50%>
<img alt="Header image of brainstate." src="https://github.com/chaobrain/brainstate/blob/main/docs/_static/brainstate.png" width=50%>
</p>



<p align="center">
<a href="https://pypi.org/project/brainstate/"><img alt="Supported Python Version" src="https://img.shields.io/pypi/pyversions/brainstate"></a>
<a href="https://github.com/brainpy/brainstate/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
<a href="https://github.com/chaobrain/brainstate/blob/main/LICENSE"><img alt="LICENSE" src="https://img.shields.io/badge/License-Apache%202.0-blue.svg"></a>
<a href='https://brainstate.readthedocs.io/en/latest/?badge=latest'>
<img src='https://readthedocs.org/projects/brainstate/badge/?version=latest' alt='Documentation Status' />
</a>
<a href="https://badge.fury.io/py/brainstate"><img alt="PyPI version" src="https://badge.fury.io/py/brainstate.svg"></a>
<a href="https://github.com/brainpy/brainstate/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/brainpy/brainstate/actions/workflows/CI.yml/badge.svg"></a>
<a href="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml"><img alt="Continuous Integration" src="https://github.com/chaobrain/brainstate/actions/workflows/CI.yml/badge.svg"></a>
</p>


Expand Down
35 changes: 35 additions & 0 deletions brainstate/augment/_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@
from brainstate.graph import (NodeStates, graph_to_tree, tree_to_graph)
from brainstate.typing import Missing, Filter
from brainstate.util import NestedDict
from ._random import restore_rngs

__all__ = [
'StateAxes',
'vmap',
'vmap_with_default_rng',
'pmap',
'mini_vmap',
'mini_pmap',
Expand Down Expand Up @@ -216,6 +218,39 @@ def vmap(
rng_restore=rng_restore, )


def vmap_with_default_rng(
fn: F | Missing = Missing(),
*,
in_axes: int | None | Sequence[Any] = 0,
out_axes: Any = 0,
axis_name: AxisName | None = None,
axis_size: int | None = None,
spmd_axis_name: AxisName | tuple[AxisName, ...] | None = None,
) -> F | Callable[[F], F]:
if isinstance(fn, Missing):
return functools.partial(
vmap_with_default_rng,
in_axes=in_axes,
out_axes=out_axes,
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name,
) # type: ignore[return-value]

return restore_rngs(
_map_transform(
jax.vmap,
fn,
in_axes=in_axes,
out_axes=out_axes,
axis_name=axis_name,
axis_size=axis_size,
spmd_axis_name=spmd_axis_name

)
)


def pmap(
fn: Callable[[NestedDict, ...], Any] | Missing = Missing(),
axis_name: Optional[AxisName] = None,
Expand Down
4 changes: 2 additions & 2 deletions brainstate/graph/_graph_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,7 @@ def nodes(
node_maps = tuple(FlattedDict(flat_node) for flat_node in flat_nodes)
if num_filters < 2:
return node_maps[0]
return node_maps
return node_maps[:num_filters]


def _states_generator(node, allowed_hierarchy) -> Iterable[Tuple[PathParts, State]]:
Expand Down Expand Up @@ -1268,7 +1268,7 @@ def states(
state_maps = tuple(FlattedDict(flat_state) for flat_state in flat_states)
if num_filters < 2:
return state_maps[0]
return state_maps
return state_maps[:num_filters]


@overload
Expand Down
27 changes: 25 additions & 2 deletions brainstate/graph/_graph_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,7 +557,7 @@ def run(self) -> None:
thread.join()


class TestFlatten(unittest.TestCase):
class TestGraphOperation(unittest.TestCase):
def test1(self):
class MyNode(bst.graph.Node):
def __init__(self):
Expand Down Expand Up @@ -637,7 +637,7 @@ def __init__(self):
assert not hasattr(model.b, 'V')
# print(model.states())

def test2(self):
def test_treefy_split(self):
class MLP(bst.graph.Node):
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
self.input = bst.nn.Linear(din, dmid)
Expand All @@ -661,6 +661,29 @@ def __call__(self, x):
# nest_states = states.to_nest()
# print(nest_states)

def test_states(self):
class MLP(bst.graph.Node):
def __init__(self, din: int, dmid: int, dout: int, n_layer: int = 3):
self.input = bst.nn.Linear(din, dmid)
self.layers = [bst.nn.Linear(dmid, dmid) for _ in range(n_layer)]
self.output = bst.nn.LIF(dout)

def __call__(self, x):
x = bst.functional.relu(self.input(x))
for layer in self.layers:
x = bst.functional.relu(layer(x))
return self.output(x)

model = bst.nn.init_all_states(MLP(2, 1, 3))
states = bst.graph.states(model)
print(states)
nest_states = states.to_nest()
print(nest_states)

params, others = bst.graph.states(model, bst.ParamState, bst.ShortTermState)
print(params)
print(others)


if __name__ == '__main__':
absltest.main()
17 changes: 9 additions & 8 deletions brainstate/nn/_collective_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@

from brainstate._utils import set_module_as
from ._module import Module
from brainstate.graph import nodes

# the maximum order
MAX_ORDER = 10
Expand Down Expand Up @@ -85,12 +86,12 @@ def init_all_states(target: T, *args, exclude=None, **kwargs) -> T:
"""
nodes_with_order = []

nodes = target.nodes(Module)
nodes_ = nodes(target).filter(Module)
if exclude is not None:
nodes = nodes - nodes.filter(exclude)
nodes_ = nodes_ - nodes_.filter(exclude)

# reset node whose `init_state` has no `call_order`
for node in list(nodes.values()):
for node in list(nodes_.values()):
if hasattr(node.init_state, 'call_order'):
nodes_with_order.append(node)
else:
Expand All @@ -117,7 +118,7 @@ def reset_all_states(target: Module, *args, **kwargs) -> Module:
nodes_with_order = []

# reset node whose `init_state` has no `call_order`
for node in list(target.nodes(Module).values()):
for path, node in nodes(target).items():
if hasattr(node.reset_state, 'call_order'):
nodes_with_order.append(node)
else:
Expand Down Expand Up @@ -149,12 +150,12 @@ def load_all_states(target: Module, state_dict: Dict, **kwargs):
"""
missing_keys = []
unexpected_keys = []
for name, node in target.nodes().items():
r = node.load_state(state_dict[name], **kwargs)
for path, node in nodes(target).items():
r = node.load_state(state_dict[path], **kwargs)
if r is not None:
missing, unexpected = r
missing_keys.extend([f'{name}.{key}' for key in missing])
unexpected_keys.extend([f'{name}.{key}' for key in unexpected])
missing_keys.extend([f'{path}.{key}' for key in missing])
unexpected_keys.extend([f'{path}.{key}' for key in unexpected])
return StateLoadResult(missing_keys, unexpected_keys)


Expand Down
12 changes: 7 additions & 5 deletions docs/index.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
``brainstate`` documentation
============================

`brainstate <https://github.com/brainpy/brainstate>`_ implements a ``State``-based Transformation System for Program Compilation and Augmentation.
`brainstate <https://github.com/chaobrain/brainstate>`_ implements a ``State``-based Transformation System for Program Compilation and Augmentation.

``BrainState`` is specifically designed to work with models that have states, including rate-based recurrent neural networks, spiking neural networks, and other dynamical systems.

Expand Down Expand Up @@ -37,7 +37,7 @@ Features

.. div:: sd-font-normal

``BrainState`` enables `event-driven computation <./apis/event.html>`__ for spiking neural networks,
``BrainState`` enables `event-driven computation <./apis/event.rst>`__ for spiking neural networks,
and thus obtains unprecedented performance on CPU and GPU devices.


Expand All @@ -52,7 +52,7 @@ Features

.. div:: sd-font-normal

``BrainState`` supports `program compilation <./apis/compile.html>`__ (such as just-in-time compilation) with its `state-based <./apis/brainstate.html>`__ IR construction.
``BrainState`` supports `program compilation <./apis/compile.rst>`__ (such as just-in-time compilation) with its `state-based <./apis/brainstate.rst>`__ IR construction.



Expand All @@ -66,7 +66,7 @@ Features

.. div:: sd-font-normal

``BrainState`` supports program `functionality augmentation <./apis/augment.html>`__ (such batching) with its `graph-based <./apis/graph.html>`__ Python objects.
``BrainState`` supports program `functionality augmentation <./apis/augment.rst>`__ (such batching) with its `graph-based <./apis/graph.rst>`__ Python objects.



Expand Down Expand Up @@ -112,7 +112,8 @@ We are building the `BDP ecosystem <https://ecosystem-for-brain-dynamics.readthe
:maxdepth: 1
:caption: Quickstart

quickstart/concepts.ipynb
quickstart/concepts-en.ipynb
quickstart/concepts-zh.ipynb
quickstart/ann_training.ipynb
quickstart/snn_training.ipynb
quickstart/snn_training.ipynb
Expand All @@ -124,6 +125,7 @@ We are building the `BDP ecosystem <https://ecosystem-for-brain-dynamics.readthe
:maxdepth: 2
:caption: Tutorials

tutorials/pygraph-zh.ipynb
tutorials/random_numbers.ipynb
tutorials/event_driven_computation.ipynb
tutorials/gspmd.ipynb
Expand Down
55 changes: 55 additions & 0 deletions docs/quickstart/concepts-en.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
{
"cells": [
{
"cell_type": "markdown",
"source": [
"# Key Concepts\n",
"\n",
"This section provides a brief introduction to some of the key concepts of the ``BrainState`` framework."
],
"metadata": {
"collapsed": false
},
"id": "2880d52052c4a9d4"
},
{
"cell_type": "markdown",
"source": [
"``BrainState`` is a high-performance computing framework dedicated to brain dynamics modelling, built on top of [JAX](https://github.com/jax-ml/jax). It provides neuroscience researchers, computational neuroscientists, and brain-like computing partitioner with a complete toolchain for building, optimising, and deploying neural network models of all kinds. It integrates modern hardware acceleration, automatic differentiation, event-driven computation, and other advanced features designed for neural networks, especially Spiking Neural Networks (SNNs). The following tutorials will introduce its core features and usage scenarios to help you get started and understand how to build and optimise brain dynamics models with BrainState. "
],
"metadata": {
"collapsed": false
},
"id": "3d60555bec9e4eab"
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false
},
"id": "f1a8630b34c771ae"
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Original file line number Diff line number Diff line change
Expand Up @@ -907,16 +907,6 @@
},
"id": "5f0d1171923e73e2"
},
{
"cell_type": "markdown",
"source": [
"# Key Concepts"
],
"metadata": {
"collapsed": false
},
"id": "2880d52052c4a9d4"
},
{
"cell_type": "markdown",
"source": [],
Expand Down
Loading

0 comments on commit 6d63d95

Please sign in to comment.