Skip to content

Commit

Permalink
Merge pull request #344 from chaoming0625/master
Browse files Browse the repository at this point in the history
The update and fix of functions and models
  • Loading branch information
chaoming0625 authored Mar 12, 2023
2 parents 0a519c0 + d2bd305 commit 454969b
Show file tree
Hide file tree
Showing 19 changed files with 801 additions and 185 deletions.
11 changes: 7 additions & 4 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.3.6"
__version__ = "2.3.7"


# fundamental supporting modules
Expand Down Expand Up @@ -61,20 +61,23 @@
experimental,
)
from brainpy._src.dyn.base import not_pass_shared
from brainpy._src.dyn.base import (DynamicalSystem,
DynamicalSystemNS,
from brainpy._src.dyn.base import (DynamicalSystem as DynamicalSystem,
Container as Container,
Sequential as Sequential,
Network as Network,
NeuGroup as NeuGroup,
NeuGroupNS as NeuGroupNS,
SynConn as SynConn,
SynOut as SynOut,
SynSTP as SynSTP,
SynLTP as SynLTP,
TwoEndConn as TwoEndConn,
CondNeuGroup as CondNeuGroup,
Channel as Channel)
from brainpy._src.dyn.base import (DynamicalSystemNS as DynamicalSystemNS,
NeuGroupNS as NeuGroupNS)
from brainpy._src.dyn.synapses_v2.base import (SynOutNS as SynOutNS,
SynSTPNS as SynSTPNS,
SynConnNS as SynConnNS, )
from brainpy._src.dyn.transform import (LoopOverTime as LoopOverTime,)
from brainpy._src.dyn.runners import (DSRunner as DSRunner) # runner
from brainpy._src.dyn.context import share, Delay
Expand Down
14 changes: 14 additions & 0 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,20 @@ def __del__(self):
def clear_input(self):
pass

def __rrshift__(self, other):
"""Support using right shift operator to call modules.
Examples
--------
>>> import brainpy as bp
>>> x = bp.math.random.rand((10, 10))
>>> l = bp.layers.Activation('tanh')
>>> y = x >> l
"""
return self.__call__(other)


class DynamicalSystemNS(DynamicalSystem):
"""Dynamical system without the need of shared parameters passing into ``update()`` function."""
Expand Down
28 changes: 16 additions & 12 deletions brainpy/_src/dyn/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,16 +130,18 @@ def update(self, x):
x = bm.as_jax(x)

if share.load('fit'):
mean = jnp.mean(x, self.axis)
mean_of_square = jnp.mean(_square(x), self.axis)
if self.axis_name is not None:
mean, mean_of_square = jnp.split(lax.pmean(jnp.concatenate([mean, mean_of_square]),
axis_name=self.axis_name,
axis_index_groups=self.axis_index_groups),
2)
var = jnp.maximum(0., mean_of_square - _square(mean))
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
mean = jnp.mean(x, self.axis)
mean_of_square = jnp.mean(_square(x), self.axis)
if self.axis_name is not None:
mean, mean_of_square = jnp.split(
lax.pmean(jnp.concatenate([mean, mean_of_square]),
axis_name=self.axis_name,
axis_index_groups=self.axis_index_groups),
2
)
var = jnp.maximum(0., mean_of_square - _square(mean))
self.running_mean.value = (self.momentum * self.running_mean + (1 - self.momentum) * mean)
self.running_var.value = (self.momentum * self.running_var + (1 - self.momentum) * var)
else:
mean = self.running_mean.value
var = self.running_var.value
Expand Down Expand Up @@ -488,7 +490,7 @@ def __init__(
self.bias = bm.TrainVar(parameter(self.bias_initializer, self.normalized_shape))
self.scale = bm.TrainVar(parameter(self.scale_initializer, self.normalized_shape))

def update(self,x):
def update(self, x):
if x.shape[-len(self.normalized_shape):] != self.normalized_shape:
raise ValueError(f'Expect the input shape should be (..., {", ".join(self.normalized_shape)}), '
f'but we got {x.shape}')
Expand Down Expand Up @@ -629,6 +631,8 @@ def __init__(
scale_initializer=scale_initializer,
mode=mode,
name=name)


BatchNorm1D = BatchNorm1d
BatchNorm2D = BatchNorm2d
BatchNorm3D = BatchNorm3d
BatchNorm3D = BatchNorm3d
7 changes: 6 additions & 1 deletion brainpy/_src/dyn/neurons/biological_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
from brainpy import check
from brainpy._src.dyn.base import NeuGroupNS
from brainpy._src.dyn.context import share
from brainpy._src.initialize import OneInit, Uniform, Initializer, parameter, noise as init_noise, variable_
from brainpy._src.initialize import (OneInit,
Uniform,
Initializer,
parameter,
noise as init_noise,
variable_)
from brainpy._src.integrators.joint_eq import JointEq
from brainpy._src.integrators.ode.generic import odeint
from brainpy._src.integrators.sde.generic import sdeint
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/neurons/input_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
from brainpy._src.dyn.context import share
import brainpy.math as bm
from brainpy._src.dyn.base import NeuGroupNS, not_pass_shared
from brainpy._src.dyn.base import NeuGroupNS
from brainpy._src.initialize import Initializer, parameter, variable_
from brainpy.types import Shape, ArrayType

Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/neurons/noise_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax.numpy as jnp
from brainpy._src.dyn.context import share
from brainpy import math as bm, initialize as init
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup, not_pass_shared
from brainpy._src.dyn.base import NeuGroupNS as NeuGroup
from brainpy._src.initialize import Initializer
from brainpy._src.integrators.sde.generic import sdeint
from brainpy.types import ArrayType, Shape
Expand Down
Loading

0 comments on commit 454969b

Please sign in to comment.