[random] random key regeneration only when the value is ArrayImpl
chaoming0625 committed Aug 28, 2023
1 parent 3f9ca59 commit 9655cb3
Showing 4 changed files with 11 additions and 13 deletions.
5 changes: 1 addition & 4 deletions brainpy/_src/integrators/sde/base.py
@@ -74,11 +74,8 @@ def __init__(
     self.intg_type = intg_type  # integral type
     self.wiener_type = wiener_type  # wiener process type
 
-    # random seed
-    self.rng = bm.random.default_rng(clone=False)
-
     # code scope
-    self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': self.rng}
+    self.code_scope = {constants.F: f, constants.G: g, 'math': jnp, 'random': bm.random.DEFAULT}
     # code lines
     self.func_name = f_names(f)
     self.code_lines = [f'def {self.func_name}({", ".join(self.arguments)}):']
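The name 'random' in this code scope is what the string-built integrator source resolves at exec time, so after this change the generated SDE functions draw their noise from the shared bm.random.DEFAULT state rather than the self.rng attribute removed above. A minimal sketch of that mechanism, assuming a made-up helper name and source string (not the actual generated code):

import jax.numpy as jnp
import brainpy.math as bm

# Scope handed to exec(): 'random' now points at the shared default RandomState.
code_scope = {'math': jnp, 'random': bm.random.DEFAULT}

code = ("def noise_term(shape, dt):\n"
        "  return random.randn(*shape) * math.sqrt(dt)\n")
exec(compile(code, '<sde_noise>', 'exec'), code_scope)

bm.random.seed(0)                           # seeding the global state once ...
print(code_scope['noise_term']((3,), 0.1))  # ... makes this draw reproducible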
14 changes: 7 additions & 7 deletions brainpy/_src/integrators/sde/normal.py
@@ -137,10 +137,10 @@ def step(self, *args, **kwargs):
         if diffusions[key] is not None:
           shape = jnp.shape(all_args[key])
           if self.wiener_type == constants.SCALAR_WIENER:
-            integral += diffusions[key] * self.rng.randn(*shape) * jnp.sqrt(dt)
+            integral += diffusions[key] * bm.random.randn(*shape) * jnp.sqrt(dt)
           else:
             shape += jnp.shape(diffusions[key])[-1:]
-            integral += jnp.sum(diffusions[key] * self.rng.randn(*shape), axis=-1) * jnp.sqrt(dt)
+            integral += jnp.sum(diffusions[key] * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt)
         integrals.append(integral)
 
     else:
@@ -156,7 +156,7 @@ def step(self, *args, **kwargs):
           noise_shape = jnp.shape(diffusions[key])
           self._check_vector_wiener_dim(noise_shape, shape)
           shape += noise_shape[-1:]
-          noise = self.rng.randn(*shape)
+          noise = bm.random.randn(*shape)
           all_noises[key] = noise * jnp.sqrt(dt)
           if self.wiener_type == constants.VECTOR_WIENER:
             y_bar = all_args[key] + jnp.sum(diffusions[key] * noise, axis=-1)
@@ -358,7 +358,7 @@ def step(self, *args, **kwargs):
           noise_shape = jnp.shape(diffusions[key])
           self._check_vector_wiener_dim(noise_shape, shape)
           shape += noise_shape[-1:]
-          noise = self.rng.randn(*shape) * jnp.sqrt(dt)
+          noise = bm.random.randn(*shape) * jnp.sqrt(dt)
           if self.wiener_type == constants.VECTOR_WIENER:
             integral += jnp.sum(diffusions[key] * noise, axis=-1)
           else:
@@ -483,7 +483,7 @@ def step(self, *args, **kwargs):
           noise_shape = jnp.shape(diffusions[key])
           self._check_vector_wiener_dim(noise_shape, shape)
           shape += noise_shape[-1:]
-          noise = self.rng.randn(*shape) * jnp.sqrt(dt)
+          noise = bm.random.randn(*shape) * jnp.sqrt(dt)
           if self.wiener_type == constants.VECTOR_WIENER:
             integral += jnp.sum(diffusions[key] * noise, axis=-1)
           else:
@@ -597,9 +597,9 @@ def integral_func(*args, **kwargs):
           noise_shape = jnp.shape(diffusion)
           self._check_vector_wiener_dim(noise_shape, shape)
           shape += noise_shape[-1:]
-          diffusion = jnp.sum(diffusion * self.rng.randn(*shape), axis=-1)
+          diffusion = jnp.sum(diffusion * bm.random.randn(*shape), axis=-1)
         else:
-          diffusion = diffusion * self.rng.randn(*shape)
+          diffusion = diffusion * bm.random.randn(*shape)
         r += diffusion * jnp.sqrt(params_in[constants.DT])
       # final result
       results.append(r)
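All of these hunks make the same substitution: per-step noise is now drawn with the module-level bm.random.randn, which delegates to the same shared default state as the code scope above. A sketch of the two shape conventions the edited lines handle; the helper below is illustrative, not a BrainPy API, and only mirrors the scalar- vs. vector-Wiener branches:

import jax.numpy as jnp
import brainpy.math as bm

def diffusion_increment(y, g_val, dt, vector_wiener=False):
  # Illustrative only: the Euler-type diffusion term g(y) * dW with dW ~ N(0, dt).
  shape = jnp.shape(y)
  if not vector_wiener:
    # scalar Wiener process: one independent increment per state element
    return g_val * bm.random.randn(*shape) * jnp.sqrt(dt)
  # vector Wiener process: g_val carries an extra trailing noise dimension, summed out
  shape += jnp.shape(g_val)[-1:]
  return jnp.sum(g_val * bm.random.randn(*shape), axis=-1) * jnp.sqrt(dt)

bm.random.seed(1)
y = jnp.zeros(4)
print(diffusion_increment(y, 0.5, dt=0.01))                                # scalar Wiener
print(diffusion_increment(y, jnp.ones((4, 2)), 0.01, vector_wiener=True))  # 2 noise dims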
4 changes: 3 additions & 1 deletion brainpy/_src/math/random.py
@@ -11,6 +11,7 @@
 from jax import lax, jit, vmap, numpy as jnp, random as jr, core, dtypes
 from jax.experimental.host_callback import call
 from jax.tree_util import register_pytree_node_class
+from jax._src.array import ArrayImpl
 
 from brainpy.check import jit_error
 from .compat_numpy import shape
@@ -491,7 +492,8 @@ def __repr__(self) -> str:
 
   @property
   def value(self):
-    if hasattr(self._value, 'is_deleted') and self._value.is_deleted():
+    if isinstance(self._value, ArrayImpl):
+      if self._value.is_deleted():
         self.seed()
     self._append_to_stack()
     return self._value
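With the import in place, the value property re-seeds only when the stored key is a concrete jax ArrayImpl whose buffers have been deleted (for example, a key consumed by a donated jit argument); any other value, such as a tracer inside a transformation, passes through untouched. A toy sketch of just that guard, outside of BrainPy's RandomState: the KeyHolder class is invented for illustration, and jax._src.array is a private JAX module whose path may change between versions.

import jax.numpy as jnp
from jax._src.array import ArrayImpl  # private JAX path, as used in the diff above

class KeyHolder:
  """Toy stand-in for a random-key container, not BrainPy's RandomState."""

  def __init__(self):
    self._value = jnp.zeros(2, dtype=jnp.uint32)  # placeholder key

  def seed(self):
    self._value = jnp.zeros(2, dtype=jnp.uint32)  # regenerate the key

  @property
  def value(self):
    # Re-seed only for concrete, already-deleted device arrays; tracers and
    # other non-ArrayImpl values skip the check entirely.
    if isinstance(self._value, ArrayImpl):
      if self._value.is_deleted():
        self.seed()
    return self._value

holder = KeyHolder()
holder._value.delete()   # simulate a key whose device buffers were freed
print(holder.value)      # guard triggers seed(), so a live key is returned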
1 change: 0 additions & 1 deletion tests/simulation/test_net_rate_SL.py
@@ -5,7 +5,6 @@
 import unittest
 import os
 
-
 show = False
 
 
