Skip to content

Commit

Permalink
Merge pull request #461 from brainpy/updates
Browse files Browse the repository at this point in the history
Creat random key automatically when it is detected
  • Loading branch information
chaoming0625 authored Aug 29, 2023
2 parents 57ce2bc + 9655cb3 commit efb450c
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 28 deletions.
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.4"
__version__ = "2.4.4.post2"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
5 changes: 4 additions & 1 deletion brainpy/_src/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def __getitem__(self, item):

def get_shargs(self) -> DotDict:
"""Get all shared arguments in the global context."""
return self._arguments.copy()
shs = self._arguments.copy()
if 'dt' not in shs:
shs['dt'] = self.dt
return shs

def clear_shargs(self, *args) -> None:
"""Clear all shared arguments in the global context."""
Expand Down
16 changes: 8 additions & 8 deletions brainpy/_src/dyn/projections/aligns.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out_cls)
post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls))

Expand Down Expand Up @@ -334,7 +334,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out_cls)
post.add_bef_update(self._post_repr, _AlignPost(syn_cls, out_cls))

Expand Down Expand Up @@ -417,7 +417,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)
post.add_bef_update(self.name, _AlignPost(syn, out))

Expand Down Expand Up @@ -534,7 +534,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)

# references
Expand Down Expand Up @@ -651,7 +651,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)

# references
Expand Down Expand Up @@ -774,7 +774,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)

# references
Expand Down Expand Up @@ -886,7 +886,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)

# references
Expand Down Expand Up @@ -1002,7 +1002,7 @@ def __init__(
if out_label is None:
out_name = self.name
else:
out_name = f'{out_label}-{self.name}'
out_name = f'{out_label} // {self.name}'
post.add_inp_fun(out_name, out)

# references
Expand Down
10 changes: 5 additions & 5 deletions brainpy/_src/dyn/rates/populations.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ def __init__(
input_var: bool = True,
):
super().__init__(size=size,
name=name,
keep_size=keep_size,
mode=mode)
name=name,
keep_size=keep_size,
mode=mode)

# model parameters
self.alpha = parameter(alpha, self.varshape, allow_none=False)
Expand Down Expand Up @@ -1025,8 +1025,8 @@ def __init__(
self.e = variable(e_initializer, self.mode, self.varshape) # Firing rate of excitatory population
self.i = variable(i_initializer, self.mode, self.varshape) # Firing rate of inhibitory population
if self.input_var:
self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population
self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population
self.Ie = variable(bm.zeros, self.mode, self.varshape) # Input of excitaory population
self.Ii = variable(bm.zeros, self.mode, self.varshape) # Input of inhibitory population

def reset(self, batch_size=None):
self.reset_state(batch_size)
Expand Down
5 changes: 1 addition & 4 deletions brainpy/_src/integrators/sde/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,8 @@ def __init__(
self.intg_type = intg_type # integral type
self.wiener_type = wiener_type # wiener process type

# random seed
self.rng = bm.random.default_rng(clone=False)

# code scope
self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': self.rng}
self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': bm.random.DEFAULT}
# code lines
self.func_name = f_names(f)
self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):']
Expand Down
14 changes: 7 additions & 7 deletions brainpy/_src/integrators/sde/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,10 @@ def step(self, *args, **kwargs):
if diffusions[key] is not None:
shape = jnp.shape(all_args[key])
if self.wiener_type == constants.SCALAR_WIENER:
integral += diffusions[key] * self.rng.randn(*shape) * jnp.sqrt(dt)
integral += diffusions[key] * bm.random.randn(*shape) * jnp.sqrt(dt)
else:
shape += jnp.shape(diffusions[key])[-1:]
integral += jnp.sum(diffusions[key] * self.rng.randn(*shape), axis=-1) * jnp.sqrt(dt)
integral += jnp.sum(diffusions[key] * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt)
integrals.append(integral)

else:
Expand All @@ -156,7 +156,7 @@ def step(self, *args, **kwargs):
noise_shape = jnp.shape(diffusions[key])
self._check_vector_wiener_dim(noise_shape, shape)
shape += noise_shape[-1:]
noise = self.rng.randn(*shape)
noise = bm.random.randn(*shape)
all_noises[key] = noise * jnp.sqrt(dt)
if self.wiener_type == constants.VECTOR_WIENER:
y_bar = all_args[key] + jnp.sum(diffusions[key] * noise, axis=-1)
Expand Down Expand Up @@ -358,7 +358,7 @@ def step(self, *args, **kwargs):
noise_shape = jnp.shape(diffusions[key])
self._check_vector_wiener_dim(noise_shape, shape)
shape += noise_shape[-1:]
noise = self.rng.randn(*shape) * jnp.sqrt(dt)
noise = bm.random.randn(*shape) * jnp.sqrt(dt)
if self.wiener_type == constants.VECTOR_WIENER:
integral += jnp.sum(diffusions[key] * noise, axis=-1)
else:
Expand Down Expand Up @@ -483,7 +483,7 @@ def step(self, *args, **kwargs):
noise_shape = jnp.shape(diffusions[key])
self._check_vector_wiener_dim(noise_shape, shape)
shape += noise_shape[-1:]
noise = self.rng.randn(*shape) * jnp.sqrt(dt)
noise = bm.random.randn(*shape) * jnp.sqrt(dt)
if self.wiener_type == constants.VECTOR_WIENER:
integral += jnp.sum(diffusions[key] * noise, axis=-1)
else:
Expand Down Expand Up @@ -597,9 +597,9 @@ def integral_func(*args, **kwargs):
noise_shape = jnp.shape(diffusion)
self._check_vector_wiener_dim(noise_shape, shape)
shape += noise_shape[-1:]
diffusion = jnp.sum(diffusion * self.rng.randn(*shape), axis=-1)
diffusion = jnp.sum(diffusion * bm.random.randn(*shape), axis=-1)
else:
diffusion = diffusion * self.rng.randn(*shape)
diffusion = diffusion * bm.random.randn(*shape)
r += diffusion * jnp.sqrt(params_in[constants.DT])
# final result
results.append(r)
Expand Down
9 changes: 9 additions & 0 deletions brainpy/_src/math/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
from jax.experimental.host_callback import call
from jax.tree_util import register_pytree_node_class
from jax._src.array import ArrayImpl

from brainpy.check import jit_error
from .compat_numpy import shape
Expand Down Expand Up @@ -489,6 +490,14 @@ def __repr__(self) -> str:
name = self.__class__.__name__
return f'{name}(key={print_code[i:]})'

@property
def value(self):
if isinstance(self._value, ArrayImpl):
if self._value.is_deleted():
self.seed()
self._append_to_stack()
return self._value

# ------------------- #
# seed and random key #
# ------------------- #
Expand Down
9 changes: 8 additions & 1 deletion brainpy/_src/math/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpy.random as nr

import brainpy.math as bm
import brainpy.math.random as br
Expand Down Expand Up @@ -548,3 +547,11 @@ def test_t2(self):
br.seed()
a = bm.random.t([1., 2.], size=None)
self.assertTupleEqual(a.shape, (2,))


class TestRandomKey(unittest.TestCase):
def test_clear_memory(self):
bm.random.split_key()
bm.clear_buffer_memory()
print(bm.random.DEFAULT.value)
self.assertTrue(isinstance(bm.random.DEFAULT.value, np.ndarray))
Loading

0 comments on commit efb450c

Please sign in to comment.