Skip to content

Commit

Permalink
Support physical unit-aware gradient computation using `brainunit.aut…
Browse files Browse the repository at this point in the history
…ograd` (#42)

* csr benchmark

* fix xla custom op bugs

* update examples

* fix memory access error when n_conn is small

* fix hashable bug

* update examples

* fix bug

* support physical unit-aware gradient using `brainunit.autograd`

* update requirements

* use `brainunit.linalg.dot` rather than `brainunit.math.dot`
  • Loading branch information
chaoming0625 authored Nov 25, 2024
1 parent aa6421a commit a61ea74
Show file tree
Hide file tree
Showing 8 changed files with 229 additions and 120 deletions.
2 changes: 1 addition & 1 deletion brainstate/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def recovery_original_values(self) -> None:
"""
for st, val in zip(self.states, self._original_state_values):
# internal use
st._value = val
st.restore_value(val)

def merge(self, *traces) -> 'StateTraceStack':
"""
Expand Down
226 changes: 112 additions & 114 deletions brainstate/augment/_autograd.py

Large diffs are not rendered by default.

97 changes: 97 additions & 0 deletions brainstate/augment/_autograd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import unittest
from pprint import pprint

import brainunit as u
import jax
import jax.numpy as jnp
import pytest
Expand Down Expand Up @@ -608,6 +609,8 @@ def __call__(self, ):
br = bst.augment.jacrev(t, grad_states=[t.x, t.y])()
self.assertTrue((br[0] == _jr[0]).all())
self.assertTrue((br[1] == _jr[1]).all())


#
# def test_jacfwd1(self):
# def f1(x, y):
Expand Down Expand Up @@ -1191,3 +1194,97 @@ def __call__(self, ):
# self.assertTrue(file.read().strip() == expect_res.strip())
#
#


class TestUnitAwareGrad(unittest.TestCase):
def test_grad1(self):
def f(x):
return u.math.sum(x ** 2)

x = jnp.array([1., 2., 3.]) * u.ms
g = bst.augment.grad(f, unit_aware=True)(x)
self.assertTrue(u.math.allclose(g, 2 * x))

def test_vector_grad1(self):
def f(x):
return x ** 3

x = jnp.array([1., 2., 3.]) * u.ms
g = bst.augment.vector_grad(f, unit_aware=True)(x)
self.assertTrue(u.math.allclose(g, 3 * x ** 2))

def test_jacrev1(self):
def f(x, y):
return u.math.asarray([x[0] * y[0],
5 * x[2] * y[1],
4 * x[1] ** 2, ])

_x = jnp.array([1., 2., 3.]) * u.ms
_y = jnp.array([10., 5.]) * u.ms

g = bst.augment.jacrev(f, unit_aware=True, argnums=(0, 1))(_x, _y)
self.assertTrue(
u.math.allclose(
g[0],
u.math.asarray([
[10., 0., 0.],
[0., 0., 25.],
[0., 16., 0.]
]) * u.ms
)
)

self.assertTrue(
u.math.allclose(
g[1],
u.math.asarray([
[1., 0.],
[0., 15.],
[0., 0.]
]) * u.ms
)
)

def test_jacfwd1(self):
def f(x, y):
return u.math.asarray([x[0] * y[0],
5 * x[2] * y[1],
4 * x[1] ** 2, ])

_x = jnp.array([1., 2., 3.]) * u.ms
_y = jnp.array([10., 5.]) * u.ms

g = bst.augment.jacfwd(f, unit_aware=True, argnums=(0, 1))(_x, _y)
self.assertTrue(
u.math.allclose(
g[0],
u.math.asarray([
[10., 0., 0.],
[0., 0., 25.],
[0., 16., 0.]
]) * u.ms
)
)

self.assertTrue(
u.math.allclose(
g[1],
u.math.asarray([
[1., 0.],
[0., 15.],
[0., 0.]
]) * u.ms
)
)

def test_hessian(self):
unit = u.ms

def scalar_function(x):
return x ** 3 + 3 * x * unit * unit + 2 * unit * unit * unit

hess = bst.augment.hessian(scalar_function, unit_aware=True)
x = jnp.array(1.0) * unit
res = hess(x)
expected_hessian = jnp.array([[6.0]]) * unit
assert u.math.allclose(res, expected_hessian)
14 changes: 14 additions & 0 deletions brainstate/event/_csr_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
4 changes: 2 additions & 2 deletions brainstate/nn/_interaction/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def update(self, x):
weight = params['weight']
if self.w_mask is not None:
weight = weight * self.w_mask
y = u.math.dot(x, weight)
y = u.linalg.dot(x, weight)
if 'bias' in params:
y = y + params['bias']
return y
Expand Down Expand Up @@ -192,7 +192,7 @@ def update(self, x):
w = functional.weight_standardization(w, self.eps, params.get('gain', None))
if self.w_mask is not None:
w = w * self.w_mask
y = u.math.dot(x, w)
y = u.linalg.dot(x, w)
if 'bias' in params:
y = y + params['bias']
return y
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dependencies = [
'jax',
'jaxlib',
'numpy',
'brainunit>=0.0.2',
'brainunit>=0.0.3',
]

dynamic = ['version']
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy
jax
jaxlib
brainunit>=0.0.2
brainunit>=0.0.3
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
author_email='[email protected]',
packages=packages,
python_requires='>=3.9',
install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.2'],
install_requires=['numpy>=1.15', 'jax', 'tqdm', 'brainunit>=0.0.3'],
url='https://github.com/chaobrain/brainstate',
project_urls={
"Bug Tracker": "https://github.com/chaobrain/brainstate/issues",
Expand Down

0 comments on commit a61ea74

Please sign in to comment.