From a61ea7427de883413a3f070b7287612da92b52f3 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Mon, 25 Nov 2024 16:39:47 +0800 Subject: [PATCH] Support physical unit-aware gradient computation using `brainunit.autograd` (#42) * csr benchmark * fix xla custom op bugs * update examples * fix memory access error when n_conn is small * fix hashable bug * update examples * fix bug * support physical unit-aware gradient using `brainunit.autograd` * update requirements * use `brainunit.linalg.dot` rather than `brainunit.math.dot` --- brainstate/_state.py | 2 +- brainstate/augment/_autograd.py | 226 +++++++++++++------------- brainstate/augment/_autograd_test.py | 97 +++++++++++ brainstate/event/_csr_benchmark.py | 14 ++ brainstate/nn/_interaction/_linear.py | 4 +- pyproject.toml | 2 +- requirements.txt | 2 +- setup.py | 2 +- 8 files changed, 229 insertions(+), 120 deletions(-) create mode 100644 brainstate/event/_csr_benchmark.py diff --git a/brainstate/_state.py b/brainstate/_state.py index 1f97893..73f26f9 100644 --- a/brainstate/_state.py +++ b/brainstate/_state.py @@ -679,7 +679,7 @@ def recovery_original_values(self) -> None: """ for st, val in zip(self.states, self._original_state_values): # internal use - st._value = val + st.restore_value(val) def merge(self, *traces) -> 'StateTraceStack': """ diff --git a/brainstate/augment/_autograd.py b/brainstate/augment/_autograd.py index 96eb049..5db7cd4 100644 --- a/brainstate/augment/_autograd.py +++ b/brainstate/augment/_autograd.py @@ -29,15 +29,11 @@ from __future__ import annotations -import inspect -from functools import partial, wraps +from functools import wraps, partial from typing import Union, Callable, Dict, Sequence, Optional, Any, Tuple, TypeVar, Iterator +import brainunit as u import jax -from jax import numpy as jnp -from jax._src.api import _vjp -from jax.api_util import argnums_partial -from jax.extend import linear_util from brainstate._state import State, StateTraceStack from brainstate._utils import set_module_as @@ -54,54 +50,15 @@ AuxData = PyTree -def _isgeneratorfunction(fun): - # re-implemented here because of https://bugs.python.org/issue33261 - while inspect.ismethod(fun): - fun = fun.__func__ - while isinstance(fun, partial): - fun = fun.func - return inspect.isfunction(fun) and bool(fun.__code__.co_flags & inspect.CO_GENERATOR) - - -def _check_callable(fun): - # In Python 3.10+, the only thing stopping us from supporting staticmethods - # is that we can't take weak references to them, which the C++ JIT requires. - if isinstance(fun, staticmethod): - raise TypeError(f"staticmethod arguments are not supported, got {fun}") - if not callable(fun): - raise TypeError(f"Expected a callable value, got {fun}") - if _isgeneratorfunction(fun): - raise TypeError(f"Expected a function, got a generator function: {fun}") - - -def functional_vector_grad(func, argnums=0, return_value: bool = False, has_aux: bool = False): - """ - Compute the gradient of a vector with respect to the input. 
- """ - _check_callable(func) - - @wraps(func) - def grad_fun(*args, **kwargs): - f = linear_util.wrap_init(func, kwargs) - f_partial, dyn_args = argnums_partial(f, argnums, args, require_static_args_hashable=False) - if has_aux: - y, vjp_fn, aux = _vjp(f_partial, *dyn_args, has_aux=True) - else: - y, vjp_fn = _vjp(f_partial, *dyn_args, has_aux=False) - leaves, tree = jax.tree.flatten(y) - tangents = jax.tree.unflatten(tree, [jnp.ones(l.shape, dtype=l.dtype) for l in leaves]) - grads = vjp_fn(tangents) - if isinstance(argnums, int): - grads = grads[0] - if has_aux: - return (grads, y, aux) if return_value else (grads, aux) - else: - return (grads, y) if return_value else grads - - return grad_fun - - -def _jacrev(fun, argnums=0, holomorphic=False, allow_int=False, has_aux=False, return_value=False): +def _jacrev( + fun, + argnums=0, + holomorphic=False, + allow_int=False, + has_aux=False, + return_value=False, + unit_aware=False, +): @wraps(fun) def fun_wrapped(*args, **kwargs): if has_aux: @@ -117,7 +74,18 @@ def fun_wrapped(*args, **kwargs): else: return y, None - transform = jax.jacrev(fun_wrapped, argnums=argnums, holomorphic=holomorphic, allow_int=allow_int, has_aux=True) + if unit_aware: + transform = u.autograd.jacrev(fun_wrapped, + argnums=argnums, + holomorphic=holomorphic, + allow_int=allow_int, + has_aux=True) + else: + transform = jax.jacrev(fun_wrapped, + argnums=argnums, + holomorphic=holomorphic, + allow_int=allow_int, + has_aux=True) @wraps(fun) def jacfun(*args, **kwargs): @@ -130,7 +98,14 @@ def jacfun(*args, **kwargs): return jacfun -def _jacfwd(fun, argnums=0, holomorphic=False, has_aux=False, return_value=False): +def _jacfwd( + fun, + argnums=0, + holomorphic=False, + has_aux=False, + return_value=False, + unit_aware=False, +): @wraps(fun) def fun_wrapped(*args, **kwargs): if has_aux: @@ -146,7 +121,16 @@ def fun_wrapped(*args, **kwargs): else: return y, None - transform = jax.jacfwd(fun_wrapped, argnums=argnums, holomorphic=holomorphic, has_aux=True) + if unit_aware: + transform = u.autograd.jacfwd(fun_wrapped, + argnums=argnums, + holomorphic=holomorphic, + has_aux=True) + else: + transform = jax.jacfwd(fun_wrapped, + argnums=argnums, + holomorphic=holomorphic, + has_aux=True) @wraps(fun) def jacfun(*args, **kwargs): @@ -323,9 +307,9 @@ def grad( argnums: Optional[Union[int, Sequence[int]]] = None, holomorphic: Optional[bool] = False, allow_int: Optional[bool] = False, - reduce_axes: Optional[Sequence[str]] = (), has_aux: Optional[bool] = None, return_value: Optional[bool] = False, + unit_aware: bool = False, ) -> GradientTransform | Callable[[Callable], GradientTransform]: """ Compute the gradient of a scalar-valued function with respect to its arguments. @@ -333,27 +317,24 @@ def grad( %s Args: - fun: callable. the scalar-valued function to be differentiated. - reduce_axes: (Sequence[str]) optional. Specifies the axes to reduce over when - differentiating with respect to array-valued arguments. The default, (), - means to differentiate each element of the output with respect to each - element of the argument. If the argument is an array, this argument controls - how many axes the output of grad has. - allow_int: (bool) optional. Whether to allow differentiating with respect to - integer valued inputs. The gradient of an integer input will have a trivial - vector-space dtype (float0). Default False. - holomorphic: (bool) optional. Whether fun is promised to be holomorphic. - Default False. - grad_states: (State, Sequence[State], Dict[str, State]) optional. 
The variables - in fun to take their gradients. - fun: the scalar-valued function to be differentiated. - argnums: (int or tuple of ints) optional. Specifies which positional - argument(s) to differentiate with respect to. - has_aux: (bool) optional. Indicates whether fun returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. - return_value: (bool) optional. Indicates whether to return the value of the - function along with the gradient. Default False. + fun: callable. the scalar-valued function to be differentiated. + allow_int: (bool) optional. Whether to allow differentiating with respect to + integer valued inputs. The gradient of an integer input will have a trivial + vector-space dtype (float0). Default False. + holomorphic: (bool) optional. Whether fun is promised to be holomorphic. + Default False. + grad_states: (State, Sequence[State], Dict[str, State]) optional. The variables + in fun to take their gradients. + fun: the scalar-valued function to be differentiated. + argnums: (int or tuple of ints) optional. Specifies which positional + argument(s) to differentiate with respect to. + has_aux: (bool) optional. Indicates whether fun returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + return_value: (bool) optional. Indicates whether to return the value of the + function along with the gradient. Default False. + unit_aware: (bool) optional. Whether to return the gradient in the unit-aware + mode. Default False. Returns: A function which computes the gradient of fun. The function takes the same @@ -367,26 +348,24 @@ def grad( if isinstance(fun, Missing): def transform(fun) -> GradientTransform: return GradientTransform(target=fun, - transform=jax.grad, + transform=u.autograd.grad if unit_aware else jax.grad, grad_states=grad_states, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, transform_params=dict(holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes)) + allow_int=allow_int)) return transform return GradientTransform(target=fun, - transform=jax.grad, + transform=u.autograd.grad if unit_aware else jax.grad, grad_states=grad_states, argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, transform_params=dict(holomorphic=holomorphic, - allow_int=allow_int, - reduce_axes=reduce_axes)) + allow_int=allow_int)) grad.__doc__ = grad.__doc__ % _doc_of_return @@ -399,6 +378,7 @@ def vector_grad( argnums: Optional[Union[int, Sequence[int]]] = None, return_value: bool = False, has_aux: Optional[bool] = None, + unit_aware: bool = False, ) -> GradientTransform | Callable[[Callable], GradientTransform]: """Take vector-valued gradients for function ``func``. @@ -410,28 +390,30 @@ def vector_grad( Parameters ---------- func: Callable - Function whose gradient is to be computed. + Function whose gradient is to be computed. grad_states : optional, ArrayType, sequence of ArrayType, dict - The variables in ``func`` to take their gradients. + The variables in ``func`` to take their gradients. has_aux: optional, bool - Indicates whether ``fun`` returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. 
+ Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. return_value : bool - Whether return the loss value. + Whether return the loss value. argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default ``0``). + positional argument(s) to differentiate with respect to (default ``0``). + unit_aware: (bool) optional. Whether to return the gradient in the unit-aware + mode. Default False. Returns ------- func : GradientTransform - The vector gradient function. + The vector gradient function. """ if isinstance(func, Missing): def transform(fun) -> GradientTransform: return GradientTransform(target=fun, - transform=functional_vector_grad, + transform=partial(u.autograd.vector_grad, unit_aware=unit_aware), grad_states=grad_states, argnums=argnums, return_value=return_value, @@ -441,7 +423,7 @@ def transform(fun) -> GradientTransform: else: return GradientTransform(target=func, - transform=functional_vector_grad, + transform=partial(u.autograd.vector_grad, unit_aware=unit_aware), grad_states=grad_states, argnums=argnums, return_value=return_value, @@ -460,6 +442,7 @@ def jacrev( return_value: bool = False, holomorphic: bool = False, allow_int: bool = False, + unit_aware: bool = False, ) -> GradientTransform: """ Extending automatic Jacobian (reverse-mode) of ``func`` to classes. @@ -473,25 +456,28 @@ def jacrev( Parameters ---------- - fun: Function whose Jacobian is to be computed. + fun: Callable + Function whose Jacobian is to be computed. grad_states : optional, ArrayType, sequence of ArrayType, dict - The variables in ``func`` to take their gradients. + The variables in ``func`` to take their gradients. has_aux: optional, bool - Indicates whether ``fun`` returns a pair where the - first element is considered the output of the mathematical function to be - differentiated and the second element is auxiliary data. Default False. + Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. return_value : bool - Whether return the loss value. + Whether return the loss value. argnums: Optional, integer or sequence of integers. - Specifies which - positional argument(s) to differentiate with respect to (default ``0``). + Specifies which + positional argument(s) to differentiate with respect to (default ``0``). holomorphic: Optional, bool. - Indicates whether ``fun`` is promised to be - holomorphic. Default False. + Indicates whether ``fun`` is promised to be + holomorphic. Default False. allow_int: Optional, bool. - Whether to allow differentiating with - respect to integer valued inputs. The gradient of an integer input will - have a trivial vector-space dtype (float0). Default False. + Whether to allow differentiating with + respect to integer valued inputs. The gradient of an integer input will + have a trivial vector-space dtype (float0). Default False. + unit_aware: (bool) optional. Whether to return the gradient in the unit-aware + mode. Default False. 
Returns ------- @@ -505,7 +491,8 @@ def jacrev( return_value=return_value, has_aux=False if has_aux is None else has_aux, transform_params=dict(holomorphic=holomorphic, - allow_int=allow_int)) + allow_int=allow_int, + unit_aware=unit_aware, )) jacrev.__doc__ = jacrev.__doc__ % _doc_of_return @@ -521,6 +508,7 @@ def jacfwd( has_aux: Optional[bool] = None, return_value: bool = False, holomorphic: bool = False, + unit_aware: bool = False, ) -> GradientTransform: """Extending automatic Jacobian (forward-mode) of ``func`` to classes. @@ -542,9 +530,11 @@ def jacfwd( return_value : bool Whether return the loss value. argnums: Optional, integer or sequence of integers. Specifies which - positional argument(s) to differentiate with respect to (default ``0``). + positional argument(s) to differentiate with respect to (default ``0``). holomorphic: Optional, bool. Indicates whether ``fun`` is promised to be - holomorphic. Default False. + holomorphic. Default False. + unit_aware: (bool) optional. Whether to return the gradient in the unit-aware + mode. Default False. Returns ------- @@ -558,7 +548,8 @@ def jacfwd( argnums=argnums, return_value=return_value, has_aux=False if has_aux is None else has_aux, - transform_params=dict(holomorphic=holomorphic)) + transform_params=dict(holomorphic=holomorphic, + unit_aware=unit_aware)) jacfwd.__doc__ = jacfwd.__doc__ % _doc_of_return @@ -569,9 +560,10 @@ def hessian( func: Callable, grad_states: Optional[Union[State, Sequence[State], Dict[str, State]]] = None, argnums: Optional[Union[int, Sequence[int]]] = None, - has_aux: bool = False, return_value: bool = False, holomorphic: bool = False, + has_aux: Optional[bool] = None, + unit_aware: bool = False, ) -> GradientTransform: """ Hessian of ``func`` as a dense array. @@ -593,6 +585,12 @@ def hessian( Indicates whether ``fun`` is promised to be holomorphic. Default False. return_value : bool Whether return the hessian values. + has_aux: Optional, bool + Indicates whether ``fun`` returns a pair where the first element is considered + the output of the mathematical function to be differentiated and the second + element is auxiliary data. Default False. + unit_aware: (bool) optional. Whether to return the gradient in the unit-aware + mode. Default False. Returns ------- @@ -600,7 +598,7 @@ def hessian( The transformed object. 
""" return GradientTransform(target=func, - transform=jax.hessian, + transform=u.autograd.hessian if unit_aware else jax.hessian, grad_states=grad_states, argnums=argnums, return_value=return_value, diff --git a/brainstate/augment/_autograd_test.py b/brainstate/augment/_autograd_test.py index e412296..cd8c1c0 100644 --- a/brainstate/augment/_autograd_test.py +++ b/brainstate/augment/_autograd_test.py @@ -19,6 +19,7 @@ import unittest from pprint import pprint +import brainunit as u import jax import jax.numpy as jnp import pytest @@ -608,6 +609,8 @@ def __call__(self, ): br = bst.augment.jacrev(t, grad_states=[t.x, t.y])() self.assertTrue((br[0] == _jr[0]).all()) self.assertTrue((br[1] == _jr[1]).all()) + + # # def test_jacfwd1(self): # def f1(x, y): @@ -1191,3 +1194,97 @@ def __call__(self, ): # self.assertTrue(file.read().strip() == expect_res.strip()) # # + + +class TestUnitAwareGrad(unittest.TestCase): + def test_grad1(self): + def f(x): + return u.math.sum(x ** 2) + + x = jnp.array([1., 2., 3.]) * u.ms + g = bst.augment.grad(f, unit_aware=True)(x) + self.assertTrue(u.math.allclose(g, 2 * x)) + + def test_vector_grad1(self): + def f(x): + return x ** 3 + + x = jnp.array([1., 2., 3.]) * u.ms + g = bst.augment.vector_grad(f, unit_aware=True)(x) + self.assertTrue(u.math.allclose(g, 3 * x ** 2)) + + def test_jacrev1(self): + def f(x, y): + return u.math.asarray([x[0] * y[0], + 5 * x[2] * y[1], + 4 * x[1] ** 2, ]) + + _x = jnp.array([1., 2., 3.]) * u.ms + _y = jnp.array([10., 5.]) * u.ms + + g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y) + self.assertTrue( + u.math.allclose( + g[0], + u.math.asarray([ + [10., 0., 0.], + [0., 0., 25.], + [0., 16., 0.] + ]) * u.ms + ) + ) + + self.assertTrue( + u.math.allclose( + g[1], + u.math.asarray([ + [1., 0.], + [0., 15.], + [0., 0.] + ]) * u.ms + ) + ) + + def test_jacfwd1(self): + def f(x, y): + return u.math.asarray([x[0] * y[0], + 5 * x[2] * y[1], + 4 * x[1] ** 2, ]) + + _x = jnp.array([1., 2., 3.]) * u.ms + _y = jnp.array([10., 5.]) * u.ms + + g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y) + self.assertTrue( + u.math.allclose( + g[0], + u.math.asarray([ + [10., 0., 0.], + [0., 0., 25.], + [0., 16., 0.] + ]) * u.ms + ) + ) + + self.assertTrue( + u.math.allclose( + g[1], + u.math.asarray([ + [1., 0.], + [0., 15.], + [0., 0.] + ]) * u.ms + ) + ) + + def test_hessian(self): + unit = u.ms + + def scalar_function(x): + return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit + + hess = bst.augment.hessian(scalar_function, unit_aware=True) + x = jnp.array(1.0) * unit + res = hess(x) + expected_hessian = jnp.array([[6.0]]) * unit + assert u.math.allclose(res, expected_hessian) diff --git a/brainstate/event/_csr_benchmark.py b/brainstate/event/_csr_benchmark.py new file mode 100644 index 0000000..23b09eb --- /dev/null +++ b/brainstate/event/_csr_benchmark.py @@ -0,0 +1,14 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== diff --git a/brainstate/nn/_interaction/_linear.py b/brainstate/nn/_interaction/_linear.py index 82d4223..a7861b7 100644 --- a/brainstate/nn/_interaction/_linear.py +++ b/brainstate/nn/_interaction/_linear.py @@ -79,7 +79,7 @@ def update(self, x): weight = params['weight'] if self.w_mask is not None: weight = weight * self.w_mask - y = u.math.dot(x, weight) + y = u.linalg.dot(x, weight) if 'bias' in params: y = y + params['bias'] return y @@ -192,7 +192,7 @@ def update(self, x): w = functional.weight_standardization(w, self.eps, params.get('gain', None)) if self.w_mask is not None: w = w * self.w_mask - y = u.math.dot(x, w) + y = u.linalg.dot(x, w) if 'bias' in params: y = y + params['bias'] return y diff --git a/pyproject.toml b/pyproject.toml index cd9987e..10035f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,7 +52,7 @@ dependencies = [ 'jax', 'jaxlib', 'numpy', - 'brainunit>=0.0.2', + 'brainunit>=0.0.3', ] dynamic = ['version'] diff --git a/requirements.txt b/requirements.txt index 3adbee4..666390b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ numpy jax jaxlib -brainunit>=0.0.2 +brainunit>=0.0.3 diff --git a/setup.py b/setup.py index 5c370d1..118fcff 100644 --- a/setup.py +++ b/setup.py @@ -58,7 +58,7 @@ author_email='chao.brain@qq.com', packages=packages, python_requires='>=3.9', - install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.2'], + install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.3'], url='https://github.com/chaobrain/brainstate', project_urls={ "Bug Tracker": "https://github.com/chaobrain/brainstate/issues",
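
For reference, a minimal usage sketch of the unit-aware gradient transform added by this patch, mirroring the new `test_grad1` case in `_autograd_test.py` above (assumes `brainstate` is imported as `bst` and `brainunit` as `u`); `vector_grad`, `jacrev`, `jacfwd`, and `hessian` accept the same `unit_aware=True` flag:

    import brainstate as bst
    import brainunit as u
    import jax.numpy as jnp

    # A scalar loss over a quantity carrying physical units (milliseconds here).
    def loss(x):
        return u.math.sum(x ** 2)

    x = jnp.array([1., 2., 3.]) * u.ms

    # With unit_aware=True the transform dispatches to brainunit.autograd.grad,
    # so the gradient keeps consistent units: d(ms**2)/d(ms) -> ms.
    g = bst.augment.grad(loss, unit_aware=True)(x)
    assert u.math.allclose(g, 2 * x)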