Commit 6967505: fix math bugs

chaoming0625 committed Jun 14, 2024
1 parent 6e76926

Showing 17 changed files with 57 additions and 67 deletions.
1 change: 0 additions & 1 deletion brainstate/_module_test.py
@@ -130,4 +130,3 @@ def __init__(self):
 print(b.states())
 print(b.states(level=0))
 print(b.states(level=0))
-
1 change: 0 additions & 1 deletion brainstate/environ.py
@@ -24,7 +24,6 @@
     'dftype', 'ditype', 'dutype', 'dctype',
 ]
 
-
 # Default, there are several shared arguments in the global context.
 I = 'i'  # the index of the current computation.
 T = 't'  # the current time of the current computation.
4 changes: 2 additions & 2 deletions brainstate/functional/__init__.py
@@ -18,9 +18,9 @@
 from ._activations import __all__ as __activations_all__
 from ._normalization import *
 from ._normalization import __all__ as __others_all__
-from ._spikes import *
-from ._spikes import __all__ as __spikes_all__
 from ._others import *
 from ._others import __all__ as __others_all__
+from ._spikes import *
+from ._spikes import __all__ as __spikes_all__
 
 __all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
1 change: 0 additions & 1 deletion brainstate/functional/_spikes.py
@@ -87,4 +87,3 @@ def spike_bitwise(x, y, op: str):
         return spike_bitwise_ixor(x, y)
     else:
         raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
-
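For context, `spike_bitwise` dispatches an op-code string to arithmetic implementations of bitwise logic on {0, 1} spike values. The `spike_bitwise_*` helpers themselves (including `ixor`) are defined earlier in `_spikes.py` and are not shown in this diff; the sketch below uses the standard arithmetic identities for 0/1-valued arrays, not the library's own definitions:

    import jax.numpy as jnp

    def spike_and(x, y):
        # AND: 1 only when both spikes are 1.
        return x * y

    def spike_or(x, y):
        # OR: inclusion-exclusion on {0, 1}.
        return x + y - x * y

    def spike_xor(x, y):
        # XOR: (x - y)^2 is 1 exactly when x != y.
        return (x - y) ** 2

    x = jnp.array([0., 1., 0., 1.])
    y = jnp.array([0., 0., 1., 1.])
    print(spike_xor(x, y))  # [0. 1. 1. 0.]

These arithmetic forms stay smooth in their inputs, which matters when the spikes are produced by surrogate-gradient functions.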
9 changes: 5 additions & 4 deletions brainstate/nn/_elementwise.py
@@ -19,11 +19,12 @@
 
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 import jax.typing
 
 from ._base import ElementWiseBlock
-from .. import math, environ, random, functional as F
+from .. import environ, random, functional as F
 from .._module import Module
 from .._state import ParamState
 from ..mixin import Mode
@@ -82,7 +83,7 @@ def __init__(self, threshold: float, value: float) -> None:
         self.value = value
 
     def __call__(self, x: ArrayLike) -> ArrayLike:
-        dtype = math.get_dtype(x)
+        dtype = bu.math.get_dtype(x)
         return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
                          x,
                          jnp.asarray(self.value, dtype=dtype))
@@ -1142,7 +1143,7 @@ def __init__(
         self.prob = prob
 
     def __call__(self, x):
-        dtype = math.get_dtype(x)
+        dtype = bu.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
         if fit_phase:
             keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1172,7 +1173,7 @@ def __init__(
         self.channel_axis = channel_axis
 
     def __call__(self, x):
-        dtype = math.get_dtype(x)
+        dtype = bu.math.get_dtype(x)
         # get fit phase
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
 
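The Dropout `__call__` hunks above gate masking on the global 'fit' flag from `environ` and draw a Bernoulli keep-mask with `random.bernoulli(self.prob, x.shape)`. A minimal standalone sketch of the same pattern in plain JAX; the 1/keep_prob rescaling is the standard inverted-dropout convention and is assumed here, since the scaling line falls outside the diff excerpt:

    import jax
    import jax.numpy as jnp

    def dropout(x, key, keep_prob=0.8, fitting=True):
        # Identity at inference time; mask-and-rescale while fitting.
        if not fitting:
            return x
        keep_mask = jax.random.bernoulli(key, keep_prob, x.shape)
        # Inverted dropout: dividing kept units by keep_prob preserves E[x].
        return jnp.where(keep_mask, x / keep_prob, jnp.zeros((), dtype=x.dtype))

    y = dropout(jnp.ones((4,)), jax.random.PRNGKey(0))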
7 changes: 4 additions & 3 deletions brainstate/nn/_misc.py
@@ -20,9 +20,10 @@
 from functools import wraps
 from typing import Sequence, Callable
 
+import brainunit as bu
 import jax.numpy as jnp
 
-from .. import environ, math
+from .. import environ
 from .._state import State
 from ..transform import vector_grad
 
@@ -96,7 +97,7 @@ def integral(*args, **kwargs):
         )
         dt = environ.get('dt')
         linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-        phi = math.exprel(dt * linear)
+        phi = bu.math.exprel(dt * linear)
         return args[0] + dt * phi * derivative
 
     return integral
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
     )
     dt = environ.get('dt')
     linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-    phi = math.exprel(dt * linear)
+    phi = bu.math.exprel(dt * linear)
     return args[0] + dt * phi * derivative
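Both hunks implement an exponential Euler step. For du/dt = f(u) with local linearization J = ∂f/∂u at the current state, the update is u ← u + dt · exprel(dt·J) · f(u), where exprel(z) = (e^z − 1)/z evaluated stably near z = 0 — the reason a dedicated `exprel` is used rather than the naive expression, which loses precision (and divides by zero) as the linear term vanishes. A minimal scalar sketch; the library's version instead uses `vector_grad` to obtain the linearization for full state vectors:

    import jax
    import jax.numpy as jnp

    def exprel(z):
        # exprel(z) = (e^z - 1) / z with the removable singularity at z = 0 handled.
        # expm1 keeps precision for small |z|; the where-guard avoids a 0/0.
        safe = jnp.where(z == 0.0, 1.0, z)
        return jnp.where(z == 0.0, 1.0, jnp.expm1(safe) / safe)

    def exp_euler_step(f, u, dt):
        # u' = f(u): linearize at u, then step u + dt * exprel(dt * J) * f(u).
        J = jax.grad(f)(u)
        return u + dt * exprel(dt * J) * f(u)

    # For a linear system u' = -u the step is exact: returns u * exp(-dt).
    u_next = exp_euler_step(lambda u: -u, 1.0, 0.1)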
5 changes: 3 additions & 2 deletions brainstate/nn/_others.py
@@ -19,10 +19,11 @@
 from functools import partial
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 
 from ._base import DnnLayer
-from .. import random, math, environ, typing, init
+from .. import random, environ, typing, init
 from ..mixin import Mode
 
 __all__ = [
@@ -88,7 +89,7 @@ def init_state(self, batch_size=None, **kwargs):
         self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
 
     def update(self, x):
-        dtype = math.get_dtype(x)
+        dtype = bu.math.get_dtype(x)
         fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
         if fit_phase:
             assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
41 changes: 21 additions & 20 deletions brainstate/nn/_poolings.py
@@ -21,12 +21,13 @@
 from typing import Sequence, Optional
 from typing import Union, Tuple, Callable, List
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 import numpy as np
 
 from ._base import DnnLayer, ExplicitInOutSize
-from .. import environ, math
+from .. import environ
 from ..mixin import Mode
 from ..typing import Size
 
@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):
     Args:
       in_size: Sequence of int. The shape of the input tensor.
-      start_dim: first dim to flatten (default = 1).
-      end_dim: last dim to flatten (default = -1).
+      start_axis: first dim to flatten (default = 1).
+      end_axis: last dim to flatten (default = -1).
 
     Examples::
 
         >>> import brainstate as bst
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):
 
     def __init__(
         self,
-        start_dim: int = 0,
-        end_dim: int = -1,
+        start_axis: int = 0,
+        end_axis: int = -1,
         in_size: Optional[Size] = None
     ) -> None:
         super().__init__()
-        self.start_dim = start_dim
-        self.end_dim = end_dim
+        self.start_axis = start_axis
+        self.end_axis = end_axis
 
         if in_size is not None:
             self.in_size = tuple(in_size)
-            y = jax.eval_shape(functools.partial(math.flatten, start_dim=start_dim, end_dim=end_dim),
+            y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
             self.out_size = y.shape
 
     def update(self, x):
         if self._in_size is None:
-            start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+            start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
         else:
             assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
             dim_diff = x.ndim - len(self.in_size)
             if self.in_size != x.shape[dim_diff:]:
                 raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
-            if self.start_dim >= 0:
-                start_dim = self.start_dim + dim_diff
+            if self.start_axis >= 0:
+                start_axis = self.start_axis + dim_diff
             else:
-                start_dim = x.ndim + self.start_dim
-        return math.flatten(x, start_dim, self.end_dim)
+                start_axis = x.ndim + self.start_axis
+        return bu.math.flatten(x, start_axis, self.end_axis)
 
     def __repr__(self) -> str:
-        return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+        return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
 
 
 class Unflatten(DnnLayer, ExplicitInOutSize):
@@ -124,23 +125,23 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
     :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
 
     Args:
-      dim: int, Dimension to be unflattened.
+      axis: int, Dimension to be unflattened.
       sizes: Sequence of int. New shape of the unflattened dimension.
      in_size: Sequence of int. The shape of the input tensor.
     """
     __module__ = 'brainstate.nn'
 
     def __init__(
         self,
-        dim: int,
+        axis: int,
         sizes: Size,
         mode: Mode = None,
         name: str = None,
         in_size: Optional[Size] = None
     ) -> None:
         super().__init__(mode=mode, name=name)
 
-        self.dim = dim
+        self.axis = axis
         self.sizes = sizes
         if isinstance(sizes, (tuple, list)):
             for idx, elem in enumerate(sizes):
@@ -152,15 +153,15 @@ def __init__(
 
         if in_size is not None:
             self.in_size = tuple(in_size)
-            y = jax.eval_shape(functools.partial(math.unflatten, dim=dim, sizes=sizes),
+            y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes),
                                jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
             self.out_size = y.shape
 
     def update(self, x):
-        return math.unflatten(x, self.dim, self.sizes)
+        return bu.math.unflatten(x, self.axis, self.sizes)
 
     def __repr__(self):
-        return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})'
+        return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})'
 
 
 class _MaxPool(DnnLayer, ExplicitInOutSize):
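The `start_dim`/`end_dim` → `start_axis`/`end_axis` and `dim` → `axis` renames track the keyword names of `bu.math.flatten` and `bu.math.unflatten`. A usage sketch based on the constructor and `update` semantics visible in this diff (shapes are illustrative):

    import brainstate as bst
    import jax.numpy as jnp

    # Flatten axes 1..-1 of an (8, 3, 4, 5) input into (8, 60).
    flatten = bst.nn.Flatten(start_axis=1, end_axis=-1)
    y = flatten.update(jnp.ones((8, 3, 4, 5)))   # shape (8, 60)

    # Unflatten expands axis 1 back to (4, 5); as the docstring requires,
    # prod(sizes) must equal the size of the unflattened axis (4 * 5 == 20).
    unflatten = bst.nn.Unflatten(axis=1, sizes=(4, 5))
    z = unflatten.update(jnp.ones((8, 20)))      # shape (8, 4, 5)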
1 change: 0 additions & 1 deletion brainstate/optim/__init__.py
@@ -20,4 +20,3 @@
 from ._sgd_optimizer import __all__ as optimizer_all
 
 __all__ = scheduler_all + optimizer_all
-
35 changes: 18 additions & 17 deletions brainstate/optim/_sgd_optimizer.py
@@ -18,11 +18,12 @@
 import functools
 from typing import Union, Dict, Optional, Tuple, Any, TypeVar
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 
 from ._lr_scheduler import make_schedule, LearningRateScheduler
-from .. import environ, math
+from .. import environ
 from .._module import Module
 from .._state import State, LongTermState, StateDictManager, visible_state_dict
 
@@ -282,7 +283,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -349,7 +350,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -417,7 +418,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -500,8 +501,8 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
@@ -574,7 +575,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -647,8 +648,8 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -730,7 +731,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -835,10 +836,10 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr = self.lr()
@@ -989,10 +990,10 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
         for k, v in train_states.items():
             assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
             self.weight_states.add_unique_elem(k, v)
-            self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-            self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+            self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+            self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
             if self.amsgrad:
-                self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
+                self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
     def update(self, grads: dict):
         lr_old = self.lr()
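Every optimizer here allocates auxiliary state (momentum, cache, first/second moments) shaped like the corresponding weight pytree, so the only change is where `tree_zeros_like` comes from. A sketch of the underlying pattern in plain JAX — `tree_zeros_like` below is a stand-in with the semantics these call sites assume (zeros_like mapped over a pytree), and `sgd_momentum` is a hypothetical helper, not the library's optimizer class:

    import jax
    import jax.numpy as jnp

    def tree_zeros_like(tree):
        # zeros_like applied leaf-wise over a pytree of arrays.
        return jax.tree_util.tree_map(jnp.zeros_like, tree)

    def sgd_momentum(params, grads, velocity, lr=0.01, momentum=0.9):
        # Classic momentum: v <- momentum * v + g; p <- p - lr * v, leaf-wise.
        velocity = jax.tree_util.tree_map(lambda v, g: momentum * v + g, velocity, grads)
        params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
        return params, velocity

    params = {'w': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
    velocity = tree_zeros_like(params)  # one zero leaf per trainable weight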
5 changes: 2 additions & 3 deletions brainstate/transform/__init__.py
@@ -17,10 +17,10 @@
 This module contains the functions for the transformation of the brain data.
 """
 
-from ._control import *
-from ._control import __all__ as _controls_all
 from ._autograd import *
 from ._autograd import __all__ as _gradients_all
+from ._control import *
+from ._control import __all__ as _controls_all
 from ._jit import *
 from ._jit import __all__ as _jit_all
 from ._jit_error import *
@@ -33,4 +33,3 @@
 __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all
 
 del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
-
2 changes: 1 addition & 1 deletion brainstate/transform/_autograd.py
@@ -25,8 +25,8 @@
 from jax.api_util import argnums_partial
 from jax.extend import linear_util
 
-from brainstate._utils import set_module_as
 from brainstate._state import State, StateTrace, StateDictManager
+from brainstate._utils import set_module_as
 
 __all__ = [
     'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
2 changes: 0 additions & 2 deletions brainstate/transform/_autograd_test.py
@@ -537,7 +537,6 @@ def f1(x, y):
 
     def test_jacrev_return_aux1(self):
         with bc.environ.context(precision=64):
-
             def f1(x, y):
                 a = 4 * x[1] ** 2 - 2 * x[2]
                 r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,7 +563,6 @@ def f1(x, y):
             assert (vec == _r).all()
 
 
-
 class TestClassFuncJacobian(unittest.TestCase):
     def test_jacrev1(self):
         def f1(x, y):