From 6967505b70281464dfe19e5090a8d9efa8c8d888 Mon Sep 17 00:00:00 2001
From: Chaoming Wang
Date: Fri, 14 Jun 2024 19:58:05 +0800
Subject: [PATCH] fix math bugs

---
 brainstate/_module_test.py               |  1 -
 brainstate/environ.py                    |  1 -
 brainstate/functional/__init__.py        |  4 +--
 brainstate/functional/_spikes.py         |  1 -
 brainstate/nn/_elementwise.py            |  9 +++---
 brainstate/nn/_misc.py                   |  7 ++--
 brainstate/nn/_others.py                 |  5 +--
 brainstate/nn/_poolings.py               | 41 ++++++++++++------------
 brainstate/optim/__init__.py             |  1 -
 brainstate/optim/_sgd_optimizer.py       | 35 ++++++++++----------
 brainstate/transform/__init__.py         |  5 ++-
 brainstate/transform/_autograd.py        |  2 +-
 brainstate/transform/_autograd_test.py   |  2 --
 brainstate/transform/_jit_test.py        |  3 --
 brainstate/transform/_make_jaxpr.py      |  1 -
 brainstate/transform/_make_jaxpr_test.py |  2 --
 brainstate/transform/_progress_bar.py    |  4 +--
 17 files changed, 57 insertions(+), 67 deletions(-)

diff --git a/brainstate/_module_test.py b/brainstate/_module_test.py
index 305cca7..392da9d 100644
--- a/brainstate/_module_test.py
+++ b/brainstate/_module_test.py
@@ -130,4 +130,3 @@ def __init__(self):
     print(b.states())
     print(b.states(level=0))
     print(b.states(level=0))
-
diff --git a/brainstate/environ.py b/brainstate/environ.py
index 60f011e..6740ee3 100644
--- a/brainstate/environ.py
+++ b/brainstate/environ.py
@@ -24,7 +24,6 @@
   'dftype', 'ditype', 'dutype', 'dctype',
 ]
 
-
 # Default, there are several shared arguments in the global context.
 I = 'i'  # the index of the current computation.
 T = 't'  # the current time of the current computation.
diff --git a/brainstate/functional/__init__.py b/brainstate/functional/__init__.py
index cb0516e..5ef4162 100644
--- a/brainstate/functional/__init__.py
+++ b/brainstate/functional/__init__.py
@@ -18,9 +18,9 @@
 from ._activations import __all__ as __activations_all__
 from ._normalization import *
 from ._normalization import __all__ as __others_all__
-from ._spikes import *
-from ._spikes import __all__ as __spikes_all__
 from ._others import *
 from ._others import __all__ as __others_all__
+from ._spikes import *
+from ._spikes import __all__ as __spikes_all__
 
 __all__ = __spikes_all__ + __others_all__ + __activations_all__ + __others_all__
diff --git a/brainstate/functional/_spikes.py b/brainstate/functional/_spikes.py
index 89b7c2a..c6b342c 100644
--- a/brainstate/functional/_spikes.py
+++ b/brainstate/functional/_spikes.py
@@ -87,4 +87,3 @@ def spike_bitwise(x, y, op: str):
     return spike_bitwise_ixor(x, y)
   else:
     raise NotImplementedError(f"Unsupported bitwise operation: {op}.")
-
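Note on `functional/__init__.py` above: the alphabetized import order makes an existing quirk easier to see. Both `._normalization` and `._others` bind their export lists to the same alias `__others_all__`, so the first binding is shadowed, and the final `__all__` concatenates the `._others` names twice while dropping the `._normalization` names entirely. A sketch of the same aggregation with distinct aliases (illustrative only; this is not part of the commit and would live inside the package `__init__`):

  from ._activations import __all__ as __activations_all__
  from ._normalization import __all__ as __normalization_all__
  from ._others import __all__ as __others_all__
  from ._spikes import __all__ as __spikes_all__

  __all__ = __spikes_all__ + __normalization_all__ + __others_all__ + __activations_all__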
diff --git a/brainstate/nn/_elementwise.py b/brainstate/nn/_elementwise.py
index eac8356..638f49a 100644
--- a/brainstate/nn/_elementwise.py
+++ b/brainstate/nn/_elementwise.py
@@ -19,11 +19,12 @@
 
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 import jax.typing
 
 from ._base import ElementWiseBlock
-from .. import math, environ, random, functional as F
+from .. import environ, random, functional as F
 from .._module import Module
 from .._state import ParamState
 from ..mixin import Mode
@@ -82,7 +83,7 @@ def __init__(self, threshold: float, value: float) -> None:
     self.value = value
 
   def __call__(self, x: ArrayLike) -> ArrayLike:
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     return jnp.where(x > jnp.asarray(self.threshold, dtype=dtype),
                      x,
                      jnp.asarray(self.value, dtype=dtype))
@@ -1142,7 +1143,7 @@ def __init__(
     self.prob = prob
 
   def __call__(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
     if fit_phase:
       keep_mask = random.bernoulli(self.prob, x.shape)
@@ -1172,7 +1173,7 @@ def __init__(
     self.channel_axis = channel_axis
 
   def __call__(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
 
     # get fit phase
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
diff --git a/brainstate/nn/_misc.py b/brainstate/nn/_misc.py
index 3e9b8cf..989efcc 100644
--- a/brainstate/nn/_misc.py
+++ b/brainstate/nn/_misc.py
@@ -20,9 +20,10 @@
 from functools import wraps
 from typing import Sequence, Callable
 
+import brainunit as bu
 import jax.numpy as jnp
 
-from .. import environ, math
+from .. import environ
 from .._state import State
 from ..transform import vector_grad
 
@@ -96,7 +97,7 @@ def integral(*args, **kwargs):
     )
     dt = environ.get('dt')
     linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-    phi = math.exprel(dt * linear)
+    phi = bu.math.exprel(dt * linear)
     return args[0] + dt * phi * derivative
 
   return integral
@@ -128,5 +129,5 @@ def exp_euler_step(fun: Callable, *args, **kwargs):
   )
   dt = environ.get('dt')
   linear, derivative = vector_grad(fun, argnums=0, return_value=True)(*args, **kwargs)
-  phi = math.exprel(dt * linear)
+  phi = bu.math.exprel(dt * linear)
   return args[0] + dt * phi * derivative
diff --git a/brainstate/nn/_others.py b/brainstate/nn/_others.py
index 3015f8c..1e46917 100644
--- a/brainstate/nn/_others.py
+++ b/brainstate/nn/_others.py
@@ -19,10 +19,11 @@
 from functools import partial
 from typing import Optional
 
+import brainunit as bu
 import jax.numpy as jnp
 
 from ._base import DnnLayer
-from .. import random, math, environ, typing, init
+from .. import random, environ, typing, init
 from ..mixin import Mode
 
 __all__ = [
@@ -88,7 +89,7 @@ def init_state(self, batch_size=None, **kwargs):
     self.mask = init.param(partial(random.bernoulli, self.prob), self.in_size, batch_size)
 
   def update(self, x):
-    dtype = math.get_dtype(x)
+    dtype = bu.math.get_dtype(x)
     fit_phase = environ.get('fit', desc='Whether this is a fitting process. Bool.')
     if fit_phase:
       assert self.mask.shape == x.shape, (f"Input shape {x.shape} does not match the mask shape {self.mask.shape}. "
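For the `exprel` calls in `nn/_misc.py` above: `exprel(z) = (e^z - 1) / z`, so `exp_euler_step` computes the exponential Euler update `x <- x + dt * exprel(dt * J) * f(x)`, where `J` is the diagonal Jacobian returned by `vector_grad`. For linear dynamics the step is exact, which a minimal jax.numpy check illustrates (the `tau`, `dt`, and `x` values are arbitrary, and `exprel` is sketched inline rather than imported from brainunit):

  import jax.numpy as jnp

  def exprel(z):
    # (e^z - 1) / z, the phi-function of first-order exponential integrators
    return jnp.where(z == 0, 1.0, (jnp.exp(z) - 1.0) / z)

  tau, dt, x = 10.0, 0.1, 2.0
  f = lambda x: -x / tau   # leaky dynamics dx/dt = -x / tau
  linear = -1.0 / tau      # df/dx, the diagonal Jacobian
  x_next = x + dt * exprel(dt * linear) * f(x)
  # the exponential Euler step reproduces the exact solution x * exp(-dt / tau)
  assert jnp.allclose(x_next, x * jnp.exp(-dt / tau))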
diff --git a/brainstate/nn/_poolings.py b/brainstate/nn/_poolings.py
index a84de6d..9547026 100644
--- a/brainstate/nn/_poolings.py
+++ b/brainstate/nn/_poolings.py
@@ -21,12 +21,13 @@
 
 from typing import Sequence, Optional
 from typing import Union, Tuple, Callable, List
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 import numpy as np
 
 from ._base import DnnLayer, ExplicitInOutSize
-from .. import environ, math
+from .. import environ
 from ..mixin import Mode
 from ..typing import Size
@@ -53,8 +54,8 @@ class Flatten(DnnLayer, ExplicitInOutSize):
 
   Args:
     in_size: Sequence of int. The shape of the input tensor.
-    start_dim: first dim to flatten (default = 1).
-    end_dim: last dim to flatten (default = -1).
+    start_axis: first axis to flatten (default = 0).
+    end_axis: last axis to flatten (default = -1).
 
   Examples::
 
     >>> import brainstate as bst
@@ -74,36 +75,36 @@ class Flatten(DnnLayer, ExplicitInOutSize):
 
   def __init__(
       self,
-      start_dim: int = 0,
-      end_dim: int = -1,
+      start_axis: int = 0,
+      end_axis: int = -1,
       in_size: Optional[Size] = None
   ) -> None:
     super().__init__()
 
-    self.start_dim = start_dim
-    self.end_dim = end_dim
+    self.start_axis = start_axis
+    self.end_axis = end_axis
 
     if in_size is not None:
       self.in_size = tuple(in_size)
-      y = jax.eval_shape(functools.partial(math.flatten, start_dim=start_dim, end_dim=end_dim),
+      y = jax.eval_shape(functools.partial(bu.math.flatten, start_axis=start_axis, end_axis=end_axis),
                          jax.ShapeDtypeStruct(self.in_size, environ.dftype()))
       self.out_size = y.shape
 
   def update(self, x):
     if self._in_size is None:
-      start_dim = self.start_dim if self.start_dim >= 0 else x.ndim + self.start_dim
+      start_axis = self.start_axis if self.start_axis >= 0 else x.ndim + self.start_axis
     else:
       assert x.ndim >= len(self.in_size), 'Input tensor has fewer dimensions than the expected shape.'
       dim_diff = x.ndim - len(self.in_size)
       if self.in_size != x.shape[dim_diff:]:
         raise ValueError(f'Input tensor has shape {x.shape}, but expected shape {self.in_size}.')
-      if self.start_dim >= 0:
-        start_dim = self.start_dim + dim_diff
+      if self.start_axis >= 0:
+        start_axis = self.start_axis + dim_diff
       else:
-        start_dim = x.ndim + self.start_dim
-    return math.flatten(x, start_dim, self.end_dim)
+        start_axis = x.ndim + self.start_axis
+    return bu.math.flatten(x, start_axis, self.end_axis)
 
   def __repr__(self) -> str:
-    return f'{self.__class__.__name__}(start_dim={self.start_dim}, end_dim={self.end_dim})'
+    return f'{self.__class__.__name__}(start_axis={self.start_axis}, end_axis={self.end_axis})'
 
 
 class Unflatten(DnnLayer, ExplicitInOutSize):
@@ -124,7 +125,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize):
   :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.
 
   Args:
-    dim: int, Dimension to be unflattened.
+    axis: int, Dimension to be unflattened.
     sizes: Sequence of int. New shape of the unflattened dimension.
     in_size: Sequence of int. The shape of the input tensor.
""" @@ -132,7 +133,7 @@ class Unflatten(DnnLayer, ExplicitInOutSize): def __init__( self, - dim: int, + axis: int, sizes: Size, mode: Mode = None, name: str = None, @@ -140,7 +141,7 @@ def __init__( ) -> None: super().__init__(mode=mode, name=name) - self.dim = dim + self.axis = axis self.sizes = sizes if isinstance(sizes, (tuple, list)): for idx, elem in enumerate(sizes): @@ -152,15 +153,15 @@ def __init__( if in_size is not None: self.in_size = tuple(in_size) - y = jax.eval_shape(functools.partial(math.unflatten, dim=dim, sizes=sizes), + y = jax.eval_shape(functools.partial(bu.math.unflatten, axis=axis, sizes=sizes), jax.ShapeDtypeStruct(self.in_size, environ.dftype())) self.out_size = y.shape def update(self, x): - return math.unflatten(x, self.dim, self.sizes) + return bu.math.unflatten(x, self.axis, self.sizes) def __repr__(self): - return f'{self.__class__.__name__}(dim={self.dim}, sizes={self.sizes})' + return f'{self.__class__.__name__}(axis={self.axis}, sizes={self.sizes})' class _MaxPool(DnnLayer, ExplicitInOutSize): diff --git a/brainstate/optim/__init__.py b/brainstate/optim/__init__.py index 3dac80a..9c202d5 100644 --- a/brainstate/optim/__init__.py +++ b/brainstate/optim/__init__.py @@ -20,4 +20,3 @@ from ._sgd_optimizer import __all__ as optimizer_all __all__ = scheduler_all + optimizer_all - diff --git a/brainstate/optim/_sgd_optimizer.py b/brainstate/optim/_sgd_optimizer.py index 79ae119..d9e7117 100644 --- a/brainstate/optim/_sgd_optimizer.py +++ b/brainstate/optim/_sgd_optimizer.py @@ -18,11 +18,12 @@ import functools from typing import Union, Dict, Optional, Tuple, Any, TypeVar +import brainunit as bu import jax import jax.numpy as jnp from ._lr_scheduler import make_schedule, LearningRateScheduler -from .. import environ, math +from .. import environ from .._module import Module from .._state import State, LongTermState, StateDictManager, visible_state_dict @@ -282,7 +283,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = for k, v in train_states.items(): assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.' self.weight_states.add_unique_elem(k, v) - self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value)) + self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value)) def update(self, grads: dict): lr = self.lr() @@ -349,7 +350,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = for k, v in train_states.items(): assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.' self.weight_states.add_unique_elem(k, v) - self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value)) + self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value)) def update(self, grads: dict): lr = self.lr() @@ -417,7 +418,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = for k, v in train_states.items(): assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.' self.weight_states.add_unique_elem(k, v) - self.cache_states[k] = OptimState(math.tree_zeros_like(v.value)) + self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value)) def update(self, grads: dict): lr = self.lr() @@ -500,8 +501,8 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] = for k, v in train_states.items(): assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.' 
diff --git a/brainstate/optim/__init__.py b/brainstate/optim/__init__.py
index 3dac80a..9c202d5 100644
--- a/brainstate/optim/__init__.py
+++ b/brainstate/optim/__init__.py
@@ -20,4 +20,3 @@
 from ._sgd_optimizer import __all__ as optimizer_all
 
 __all__ = scheduler_all + optimizer_all
-
diff --git a/brainstate/optim/_sgd_optimizer.py b/brainstate/optim/_sgd_optimizer.py
index 79ae119..d9e7117 100644
--- a/brainstate/optim/_sgd_optimizer.py
+++ b/brainstate/optim/_sgd_optimizer.py
@@ -18,11 +18,12 @@
 import functools
 from typing import Union, Dict, Optional, Tuple, Any, TypeVar
 
+import brainunit as bu
 import jax
 import jax.numpy as jnp
 
 from ._lr_scheduler import make_schedule, LearningRateScheduler
-from .. import environ, math
+from .. import environ
 from .._module import Module
 from .._state import State, LongTermState, StateDictManager, visible_state_dict
 
@@ -282,7 +283,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -349,7 +350,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -417,7 +418,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -500,8 +501,8 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.delta_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.delta_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     weight_values, grad_values, cache_values, delta_values = to_same_dict_tree(
@@ -574,7 +575,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.cache_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.cache_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -647,8 +648,8 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -730,7 +731,7 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.momentum_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.momentum_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -835,10 +836,10 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.exp_avg_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_sq_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.exp_avg_diff_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.pre_grad_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.exp_avg_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_sq_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.exp_avg_diff_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.pre_grad_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr = self.lr()
@@ -989,10 +990,10 @@ def register_trainable_weights(self, train_states: Optional[Dict[str, State]] =
     for k, v in train_states.items():
       assert isinstance(v, State), f'"{k}" must be an instance of brainstate.State.'
       self.weight_states.add_unique_elem(k, v)
-      self.m1_states[k] = OptimState(math.tree_zeros_like(v.value))
-      self.m2_states[k] = OptimState(math.tree_zeros_like(v.value))
+      self.m1_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
+      self.m2_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
       if self.amsgrad:
-        self.vmax_states[k] = OptimState(math.tree_zeros_like(v.value))
+        self.vmax_states[k] = OptimState(bu.math.tree_zeros_like(v.value))
 
   def update(self, grads: dict):
     lr_old = self.lr()
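The recurring `OptimState(bu.math.tree_zeros_like(v.value))` pattern above allocates one zero-filled slot (momentum, cache, first/second moment, and so on) per registered weight. Assuming `tree_zeros_like` behaves like a tree-mapped `zeros_like` (a sketch, not brainunit's actual implementation):

  import jax
  import jax.numpy as jnp

  def tree_zeros_like(tree):
    # one zeros array per leaf, preserving the pytree structure
    return jax.tree_util.tree_map(jnp.zeros_like, tree)

  params = {'w': jnp.ones((3, 2)), 'b': jnp.ones((2,))}
  slots = tree_zeros_like(params)  # {'b': zeros((2,)), 'w': zeros((3, 2))}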
diff --git a/brainstate/transform/__init__.py b/brainstate/transform/__init__.py
index aecb827..9197f96 100644
--- a/brainstate/transform/__init__.py
+++ b/brainstate/transform/__init__.py
@@ -17,10 +17,10 @@
 This module contains the functions for the transformation of the brain data.
 """
 
-from ._control import *
-from ._control import __all__ as _controls_all
 from ._autograd import *
 from ._autograd import __all__ as _gradients_all
+from ._control import *
+from ._control import __all__ as _controls_all
 from ._jit import *
 from ._jit import __all__ as _jit_all
 from ._jit_error import *
@@ -33,4 +33,3 @@
 
 __all__ = _gradients_all + _jit_error_all + _controls_all + _make_jaxpr_all + _jit_all + _progress_bar_all
 del _gradients_all, _jit_error_all, _controls_all, _make_jaxpr_all, _jit_all, _progress_bar_all
-
diff --git a/brainstate/transform/_autograd.py b/brainstate/transform/_autograd.py
index b4e6173..c74f76c 100644
--- a/brainstate/transform/_autograd.py
+++ b/brainstate/transform/_autograd.py
@@ -25,8 +25,8 @@
 from jax.api_util import argnums_partial
 from jax.extend import linear_util
 
-from brainstate._utils import set_module_as
 from brainstate._state import State, StateTrace, StateDictManager
+from brainstate._utils import set_module_as
 
 __all__ = [
   'vector_grad', 'grad', 'jacrev', 'jacfwd', 'jacobian', 'hessian',
diff --git a/brainstate/transform/_autograd_test.py b/brainstate/transform/_autograd_test.py
index 2a29013..79070d3 100644
--- a/brainstate/transform/_autograd_test.py
+++ b/brainstate/transform/_autograd_test.py
@@ -537,7 +537,6 @@ def f1(x, y):
 
   def test_jacrev_return_aux1(self):
     with bc.environ.context(precision=64):
-
       def f1(x, y):
         a = 4 * x[1] ** 2 - 2 * x[2]
         r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], a, x[2] * jnp.sin(x[0])])
@@ -564,7 +563,6 @@ def f1(x, y):
 
       assert (vec == _r).all()
 
-
 class TestClassFuncJacobian(unittest.TestCase):
   def test_jacrev1(self):
     def f1(x, y):
diff --git a/brainstate/transform/_jit_test.py b/brainstate/transform/_jit_test.py
index 02e3df7..a435d34 100644
--- a/brainstate/transform/_jit_test.py
+++ b/brainstate/transform/_jit_test.py
@@ -16,7 +16,6 @@
 import unittest
 
 import jax.numpy as jnp
-import jax.stages
 
 import brainstate as bc
 
@@ -90,7 +89,6 @@ def log2(x):
     self.assertTrue(len(compiling) == 2)
 
   def test_jit_attribute_origin_fun(self):
-
     def fun1(x):
       return x
 
@@ -99,4 +97,3 @@ def fun1(x):
     self.assertTrue(isinstance(jitted_fun.stateful_fun, bc.transform.StatefulFunction))
     self.assertTrue(callable(jitted_fun.jitted_fun))
     self.assertTrue(callable(jitted_fun.clear_cache))
-
diff --git a/brainstate/transform/_make_jaxpr.py b/brainstate/transform/_make_jaxpr.py
index 8fffe5e..ba1c78a 100644
--- a/brainstate/transform/_make_jaxpr.py
+++ b/brainstate/transform/_make_jaxpr.py
@@ -75,7 +75,6 @@
 PyTree = Any
 AxisName = Hashable
 
-
 __all__ = [
   "StatefulFunction",
   "make_jaxpr",
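The `_jit_test.py` assertions above pin down the public surface of a jitted function. A hedged sketch of that surface (the `bc.transform.jit(fun1)` construction is inferred from the test, which elides it inside the hunk gap):

  import brainstate as bc

  def fun1(x):
    return x

  jitted_fun = bc.transform.jit(fun1)
  jitted_fun.stateful_fun   # the wrapped bc.transform.StatefulFunction
  jitted_fun.jitted_fun     # the underlying compiled callable
  jitted_fun.clear_cache()  # drop cached compilations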
diff --git a/brainstate/transform/_make_jaxpr_test.py b/brainstate/transform/_make_jaxpr_test.py
index b14891e..d6c9721 100644
--- a/brainstate/transform/_make_jaxpr_test.py
+++ b/brainstate/transform/_make_jaxpr_test.py
@@ -129,5 +129,3 @@ def f():
 
     with pytest.raises(ValueError):
       f()
-
-
diff --git a/brainstate/transform/_progress_bar.py b/brainstate/transform/_progress_bar.py
index 71a46fe..4186b74 100644
--- a/brainstate/transform/_progress_bar.py
+++ b/brainstate/transform/_progress_bar.py
@@ -14,13 +14,12 @@
 # ==============================================================================
 
 from __future__ import annotations
+
 import copy
 from typing import Optional
 
 import jax
 
-from brainstate import environ
-
 try:
   from tqdm.auto import tqdm
 except (ImportError, ModuleNotFoundError):
@@ -95,7 +94,6 @@ def _close_tqdm(self):
       self.tqdm_bars[0].close()
 
   def __call__(self, iter_num, *args, **kwargs):
-
     _ = jax.lax.cond(
       iter_num == 0,
       lambda: jax.debug.callback(self._define_tqdm),
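Taken together, the patch swaps brainstate's internal `math` helpers for their brainunit counterparts (`bu.math.get_dtype`, `bu.math.exprel`, `bu.math.flatten`, `bu.math.unflatten`, `bu.math.tree_zeros_like`), renames `dim`-style arguments to `axis`, sorts intra-package imports, and trims trailing blank lines. The call sites after the patch all follow one shape (a sketch using the signatures exactly as they appear in the diff; assumes brainunit is installed):

  import brainunit as bu
  import jax.numpy as jnp

  x = jnp.ones((2, 3, 4))
  dtype = bu.math.get_dtype(x)               # as in nn/_elementwise.py
  flat = bu.math.flatten(x, 1, -1)           # (2, 12), as in nn/_poolings.py
  back = bu.math.unflatten(flat, 1, (3, 4))  # (2, 3, 4)
  phi = bu.math.exprel(jnp.array(0.5))       # as in nn/_misc.py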