From 6c0071aba319fc85af8cc1bd84d2fb70221345b8 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 15 Jan 2023 22:00:15 +0800 Subject: [PATCH 1/8] unify array operation return (brainpy Array or jax Array) with a single function --- brainpy/_src/math/_utils.py | 22 +--- brainpy/_src/math/ndarray.py | 204 +++++++++++++++++++---------------- brainpy/_src/math/random.py | 3 +- brainpy/math/others.py | 2 +- 4 files changed, 115 insertions(+), 116 deletions(-) diff --git a/brainpy/_src/math/_utils.py b/brainpy/_src/math/_utils.py index 7a4950a97..6c75126af 100644 --- a/brainpy/_src/math/_utils.py +++ b/brainpy/_src/math/_utils.py @@ -3,33 +3,15 @@ import functools from typing import Callable -import jax from jax.tree_util import tree_map -from .ndarray import Array - -__all__ = [ - 'npfun_returns_bparray' -] +from .ndarray import Array, _return def _as_jax_array_(obj): return obj.value if isinstance(obj, Array) else obj -def _return(a): - return Array(a) if isinstance(a, jax.Array) and a.ndim > 1 else a - - -_return_bp_array = True - - -def npfun_returns_bparray(mode: bool): - global _return_bp_array - assert isinstance(mode, bool) - _return_bp_array = mode - - def wraps(fun: Callable): """Specialized version of functools.wraps for wrapping numpy functions. @@ -60,7 +42,7 @@ def new_fun(*args, **kwargs): if len(kwargs): kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) r = fun(*args, **kwargs) - return tree_map(_return, r) if _return_bp_array else r + return tree_map(_return, r) new_fun.__doc__ = getattr(fun, "__doc__", None) diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 2c1ca960a..f715f843a 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -5,6 +5,7 @@ import operator from typing import Optional, Tuple as TupleType +import jax import numpy as np from jax import numpy as jnp from jax.dtypes import canonicalize_dtype @@ -91,6 +92,30 @@ def _check_input_array(array): return array +def _return(a): + if _return_bp_array: + if isinstance(a, jax.Array): + if a.ndim > 1: + return Array(a) + return a + + +def is_return_bparray(): + return _return_bp_array + + +_return_bp_array = True + +def _as_jax_array_(obj): + return obj.value if isinstance(obj, Array) else obj + + +def npfun_returns_bparray(mode: bool): + global _return_bp_array + assert isinstance(mode, bool) + _return_bp_array = mode + + class Array(object): """Multiple-dimensional array in BrainPy. """ @@ -160,7 +185,7 @@ def imag(self): @property def real(self): - return Array(self.value.real) + return _return(self.value.real) @property def size(self): @@ -168,7 +193,7 @@ def size(self): @property def T(self): - return Array(self.value.T) + return _return(self.value.T) # ----------------------- # # Python inherent methods # @@ -179,8 +204,6 @@ def __repr__(self) -> str: name = self.__class__.__name__ if 'DeviceArray' in print_code: replace_name = 'DeviceArray' - elif 'Array' in print_code: - replace_name = 'Array' else: replace_name = '' if replace_name: @@ -264,40 +287,40 @@ def __len__(self) -> int: return len(self._value) def __neg__(self): - return Array(self._value.__neg__()) + return _return(self._value.__neg__()) def __pos__(self): - return Array(self._value.__pos__()) + return _return(self._value.__pos__()) def __abs__(self): - return Array(self._value.__abs__()) + return _return(self._value.__abs__()) def __invert__(self): - return Array(self._value.__invert__()) + return _return(self._value.__invert__()) def __eq__(self, oc): - return Array(self._value == _check_input_array(oc)) + return _return(self._value == _check_input_array(oc)) def __ne__(self, oc): - return Array(self._value != _check_input_array(oc)) + return _return(self._value != _check_input_array(oc)) def __lt__(self, oc): - return Array(self._value < _check_input_array(oc)) + return _return(self._value < _check_input_array(oc)) def __le__(self, oc): - return Array(self._value <= _check_input_array(oc)) + return _return(self._value <= _check_input_array(oc)) def __gt__(self, oc): - return Array(self._value > _check_input_array(oc)) + return _return(self._value > _check_input_array(oc)) def __ge__(self, oc): - return Array(self._value >= _check_input_array(oc)) + return _return(self._value >= _check_input_array(oc)) def __add__(self, oc): - return Array(self._value + _check_input_array(oc)) + return _return(self._value + _check_input_array(oc)) def __radd__(self, oc): - return Array(self._value + _check_input_array(oc)) + return _return(self._value + _check_input_array(oc)) def __iadd__(self, oc): # a += b @@ -307,10 +330,10 @@ def __iadd__(self, oc): return self def __sub__(self, oc): - return Array(self._value - _check_input_array(oc)) + return _return(self._value - _check_input_array(oc)) def __rsub__(self, oc): - return Array(_check_input_array(oc) - self._value) + return _return(_check_input_array(oc) - self._value) def __isub__(self, oc): # a -= b @@ -320,10 +343,10 @@ def __isub__(self, oc): return self def __mul__(self, oc): - return Array(self._value * _check_input_array(oc)) + return _return(self._value * _check_input_array(oc)) def __rmul__(self, oc): - return Array(_check_input_array(oc) * self._value) + return _return(_check_input_array(oc) * self._value) def __imul__(self, oc): # a *= b @@ -333,13 +356,13 @@ def __imul__(self, oc): return self def __rdiv__(self, oc): - return Array(_check_input_array(oc) / self._value) + return _return(_check_input_array(oc) / self._value) def __truediv__(self, oc): - return Array(self._value / _check_input_array(oc)) + return _return(self._value / _check_input_array(oc)) def __rtruediv__(self, oc): - return Array(_check_input_array(oc) / self._value) + return _return(_check_input_array(oc) / self._value) def __itruediv__(self, oc): # a /= b @@ -349,10 +372,10 @@ def __itruediv__(self, oc): return self def __floordiv__(self, oc): - return Array(self._value // _check_input_array(oc)) + return _return(self._value // _check_input_array(oc)) def __rfloordiv__(self, oc): - return Array(_check_input_array(oc) // self._value) + return _return(_check_input_array(oc) // self._value) def __ifloordiv__(self, oc): # a //= b @@ -362,16 +385,16 @@ def __ifloordiv__(self, oc): return self def __divmod__(self, oc): - return Array(self._value.__divmod__(_check_input_array(oc))) + return _return(self._value.__divmod__(_check_input_array(oc))) def __rdivmod__(self, oc): - return Array(self._value.__rdivmod__(_check_input_array(oc))) + return _return(self._value.__rdivmod__(_check_input_array(oc))) def __mod__(self, oc): - return Array(self._value % _check_input_array(oc)) + return _return(self._value % _check_input_array(oc)) def __rmod__(self, oc): - return Array(_check_input_array(oc) % self._value) + return _return(_check_input_array(oc) % self._value) def __imod__(self, oc): # a %= b @@ -381,10 +404,10 @@ def __imod__(self, oc): return self def __pow__(self, oc): - return Array(self._value ** _check_input_array(oc)) + return _return(self._value ** _check_input_array(oc)) def __rpow__(self, oc): - return Array(_check_input_array(oc) ** self._value) + return _return(_check_input_array(oc) ** self._value) def __ipow__(self, oc): # a **= b @@ -394,10 +417,10 @@ def __ipow__(self, oc): return self def __matmul__(self, oc): - return Array(self._value @ _check_input_array(oc)) + return _return(self._value @ _check_input_array(oc)) def __rmatmul__(self, oc): - return Array(_check_input_array(oc) @ self._value) + return _return(_check_input_array(oc) @ self._value) def __imatmul__(self, oc): # a @= b @@ -407,10 +430,10 @@ def __imatmul__(self, oc): return self def __and__(self, oc): - return Array(self._value & _check_input_array(oc)) + return _return(self._value & _check_input_array(oc)) def __rand__(self, oc): - return Array(_check_input_array(oc) & self._value) + return _return(_check_input_array(oc) & self._value) def __iand__(self, oc): # a &= b @@ -420,10 +443,10 @@ def __iand__(self, oc): return self def __or__(self, oc): - return Array(self._value | _check_input_array(oc)) + return _return(self._value | _check_input_array(oc)) def __ror__(self, oc): - return Array(_check_input_array(oc) | self._value) + return _return(_check_input_array(oc) | self._value) def __ior__(self, oc): # a |= b @@ -433,10 +456,10 @@ def __ior__(self, oc): return self def __xor__(self, oc): - return Array(self._value ^ _check_input_array(oc)) + return _return(self._value ^ _check_input_array(oc)) def __rxor__(self, oc): - return Array(_check_input_array(oc) ^ self._value) + return _return(_check_input_array(oc) ^ self._value) def __ixor__(self, oc): # a ^= b @@ -446,10 +469,10 @@ def __ixor__(self, oc): return self def __lshift__(self, oc): - return Array(self._value << _check_input_array(oc)) + return _return(self._value << _check_input_array(oc)) def __rlshift__(self, oc): - return Array(_check_input_array(oc) << self._value) + return _return(_check_input_array(oc) << self._value) def __ilshift__(self, oc): # a <<= b @@ -459,10 +482,10 @@ def __ilshift__(self, oc): return self def __rshift__(self, oc): - return Array(self._value >> _check_input_array(oc)) + return _return(self._value >> _check_input_array(oc)) def __rrshift__(self, oc): - return Array(_check_input_array(oc) >> self._value) + return _return(_check_input_array(oc) >> self._value) def __irshift__(self, oc): # a >>= b @@ -472,7 +495,7 @@ def __irshift__(self, oc): return self def __round__(self, ndigits=None): - return Array(self._value.__round__(ndigits)) + return _return(self._value.__round__(ndigits)) # ----------------------- # # JAX methods # @@ -502,28 +525,28 @@ def device_buffer(self): def all(self, axis=None, keepdims=False): """Returns True if all elements evaluate to True.""" r = self.value.all(axis=axis, keepdims=keepdims) - return r if (axis is None or keepdims) else Array(r) + return _return(r) def any(self, axis=None, keepdims=False): """Returns True if any of the elements of a evaluate to True.""" r = self.value.any(axis=axis, keepdims=keepdims) - return r if (axis is None or keepdims) else Array(r) + return _return(r) def argmax(self, axis=None): """Return indices of the maximum values along the given axis.""" - return Array(self.value.argmax(axis=axis)) + return _return(self.value.argmax(axis=axis)) def argmin(self, axis=None): """Return indices of the minimum values along the given axis.""" - return Array(self.value.argmin(axis=axis)) + return _return(self.value.argmin(axis=axis)) def argpartition(self, kth, axis=-1, kind='introselect', order=None): """Returns the indices that would partition this array.""" - return Array(self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order)) + return _return(self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order)) def argsort(self, axis=-1, kind=None, order=None): """Returns the indices that would sort this array.""" - return Array(self.value.argsort(axis=axis, kind=kind, order=order)) + return _return(self.value.argsort(axis=axis, kind=kind, order=order)) def astype(self, dtype): """Copy of the array, cast to a specified type. @@ -533,7 +556,7 @@ def astype(self, dtype): dtype: str, dtype Typecode or data-type to which the array is cast. """ - return Array(self.value.astype(dtype=dtype)) + return _return(self.value.astype(dtype=dtype)) def byteswap(self, inplace=False): """Swap the bytes of the array elements @@ -542,49 +565,47 @@ def byteswap(self, inplace=False): returning a byteswapped array, optionally swapped in-place. Arrays of byte-strings are not swapped. The real and imaginary parts of a complex number are swapped individually.""" - return Array(self.value.byteswap(inplace=inplace)) + return _return(self.value.byteswap(inplace=inplace)) def choose(self, choices, mode='raise'): """Use an index array to construct a new array from a set of choices.""" - choices = choices.value if isinstance(choices, Array) else choices - return Array(self.value.choose(choices=choices, mode=mode)) + return _return(self.value.choose(choices=_as_jax_array_(choices), mode=mode)) def clip(self, min=None, max=None): """Return an array whose values are limited to [min, max]. One of max or min must be given.""" - return Array(self.value.clip(min=min, max=max)) + return _return(self.value.clip(min=min, max=max)) def compress(self, condition, axis=None): """Return selected slices of this array along given axis.""" - condition = condition.value if isinstance(condition, Array) else condition - return Array(self.value.compress(condition=condition, axis=axis)) + return _return(self.value.compress(condition=_as_jax_array_(condition), axis=axis)) def conj(self): """Complex-conjugate all elements.""" - return Array(self.value.conj()) + return _return(self.value.conj()) def conjugate(self): """Return the complex conjugate, element-wise.""" - return Array(self.value.conjugate()) + return _return(self.value.conjugate()) def copy(self): """Return a copy of the array.""" - return Array(self.value.copy()) + return _return(self.value.copy()) def cumprod(self, axis=None, dtype=None): """Return the cumulative product of the elements along the given axis.""" - return Array(self.value.cumprod(axis=axis, dtype=dtype)) + return _return(self.value.cumprod(axis=axis, dtype=dtype)) def cumsum(self, axis=None, dtype=None): """Return the cumulative sum of the elements along the given axis.""" - return Array(self.value.cumsum(axis=axis, dtype=dtype)) + return _return(self.value.cumsum(axis=axis, dtype=dtype)) def diagonal(self, offset=0, axis1=0, axis2=1): """Return specified diagonals.""" - return Array(self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2)) + return _return(self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2)) def dot(self, b): """Dot product of two arrays.""" - return Array(self.value.dot(b.value if isinstance(b, Array) else b)) + return _return(self.value.dot(_as_jax_array_(b))) def fill(self, value): """Fill the array with a scalar value.""" @@ -592,8 +613,8 @@ def fill(self, value): raise MathError(msg) self._value = jnp.ones_like(self.value) * value - def flatten(self, order='C'): - return Array(self.value.flatten(order=order)) + def flatten(self): + return _return(self.value.flatten()) def item(self, *args): """Copy an element of an array to a standard Python scalar and return it.""" @@ -602,31 +623,31 @@ def item(self, *args): def max(self, axis=None, keepdims=False, *args, **kwargs): """Return the maximum along a given axis.""" res = self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) - return res if (axis is None or keepdims) else Array(res) + return res if (axis is None or keepdims) else _return(res) def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): """Returns the average of the array elements along given axis.""" res = self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) - return res if (axis is None or keepdims) else Array(res) + return _return(res) def min(self, axis=None, keepdims=False, *args, **kwargs): """Return the minimum along a given axis.""" res = self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) - return res if (axis is None or keepdims) else Array(res) + return _return(res) def nonzero(self): """Return the indices of the elements that are non-zero.""" - return tuple(Array(a) for a in self.value.nonzero()) + return tuple(_return(a) for a in self.value.nonzero()) def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): """Return the product of the array elements over the given axis.""" res = self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - return res if (axis is None or keepdims) else Array(res) + return _return(res) def ptp(self, axis=None, keepdims=False): """Peak to peak (maximum - minimum) value along a given axis.""" r = self.value.ptp(axis=axis, keepdims=keepdims) - return r if (axis is None or keepdims) else Array(r) + return _return(r) def put(self, indices, values): """Replaces specified elements of an array with given values. @@ -642,15 +663,15 @@ def put(self, indices, values): def ravel(self, order=None): """Return a flattened array.""" - return Array(self.value.ravel(order=order)) + return _return(self.value.ravel(order=order)) def repeat(self, repeats, axis=None): """Repeat elements of an array.""" - return Array(self.value.repeat(repeats=repeats, axis=axis)) + return _return(self.value.repeat(repeats=repeats, axis=axis)) def reshape(self, *shape, order='C'): """Returns an array containing the same data with a new shape.""" - return Array(self.value.reshape(*shape, order=order)) + return _return(self.value.reshape(*shape, order=order)) def resize(self, new_shape): """Change shape and size of array in-place.""" @@ -658,7 +679,7 @@ def resize(self, new_shape): def round(self, decimals=0): """Return ``a`` with each element rounded to the given number of decimals.""" - return Array(self.value.round(decimals=decimals)) + return _return(self.value.round(decimals=decimals)) def searchsorted(self, v, side='left', sorter=None): """Find indices where elements should be inserted to maintain order. @@ -693,8 +714,7 @@ def searchsorted(self, v, side='left', sorter=None): indices : array of ints Array of insertion points with the same shape as `v`. """ - v = v.value if isinstance(v, Array) else v - return Array(self.value.searchsorted(v=v, side=side, sorter=sorter)) + return _return(self.value.searchsorted(v=_as_jax_array_(v), side=side, sorter=sorter)) def sort(self, axis=-1, kind='quicksort', order=None): """Sort an array in-place. @@ -722,7 +742,7 @@ def sort(self, axis=-1, kind='quicksort', order=None): def squeeze(self, axis=None): """Remove axes of length one from ``a``.""" - return Array(self.value.squeeze(axis=axis)) + return _return(self.value.squeeze(axis=axis)) def std(self, axis=None, dtype=None, ddof=0, keepdims=False): """Compute the standard deviation along the specified axis. @@ -764,16 +784,16 @@ def std(self, axis=None, dtype=None, ddof=0, keepdims=False): otherwise return a reference to the output array. """ r = self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - return r if (axis is None or keepdims) else Array(r) + return _return(r) def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): """Return the sum of the array elements over the given axis.""" res = self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) - return res if (axis is None or keepdims) else Array(res) + return _return(res) def swapaxes(self, axis1, axis2): """Return a view of the array with `axis1` and `axis2` interchanged.""" - return Array(self.value.swapaxes(axis1, axis2)) + return _return(self.value.swapaxes(axis1, axis2)) def split(self, indices_or_sections, axis=0): """Split an array into multiple sub-arrays as views into ``ary``. @@ -803,12 +823,11 @@ def split(self, indices_or_sections, axis=0): sub-arrays : list of ndarrays A list of sub-arrays as views into `ary`. """ - return [Array(a) for a in self.value.split(indices_or_sections, axis=axis)] + return [_return(a) for a in self.value.split(indices_or_sections, axis=axis)] def take(self, indices, axis=None, mode=None): """Return an array formed from the elements of a at the given indices.""" - indices = indices.value if isinstance(indices, Array) else indices - return Array(self.value.take(indices=indices, axis=axis, mode=mode)) + return _return(self.value.take(indices=_as_jax_array_(indices), axis=axis, mode=mode)) def tobytes(self, order='C'): """Construct Python bytes containing the raw data bytes in the array. @@ -816,7 +835,7 @@ def tobytes(self, order='C'): Constructs Python bytes showing a copy of the raw contents of data memory. The bytes object is produced in C-order by default. This behavior is controlled by the ``order`` parameter.""" - return Array(self.value.tobytes(order=order)) + return _return(self.value.tobytes(order=order)) def tolist(self): """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. @@ -832,7 +851,7 @@ def tolist(self): def trace(self, offset=0, axis1=0, axis2=1, dtype=None): """Return the sum along diagonals of the array.""" - return Array(self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) + return _return(self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype)) def transpose(self, *axes): """Returns a view of the array with axes transposed. @@ -864,7 +883,7 @@ def transpose(self, *axes): out : ndarray View of `a`, with axes suitably permuted. """ - return Array(self.value.transpose(*axes)) + return _return(self.value.transpose(*axes)) def tile(self, reps): """Construct an array by repeating A the number of times given by reps. @@ -895,17 +914,16 @@ def tile(self, reps): c : ndarray The tiled output array. """ - reps = reps.value if isinstance(reps, Array) else reps - return Array(self.value.tile(reps)) + return _return(self.value.tile(_as_jax_array_(reps))) def var(self, axis=None, dtype=None, ddof=0, keepdims=False): """Returns the variance of the array elements, along given axis.""" r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) - return r if (axis is None or keepdims) else Array(r) + return _return(r) def view(self, dtype=None, *args, **kwargs): """New view of array with the same data.""" - return Array(self.value.view(dtype=dtype, *args, **kwargs)) + return _return(self.value.view(dtype=dtype, *args, **kwargs)) # ------------------ # NumPy support @@ -1615,7 +1633,7 @@ class VariableView(Variable): >>> origin Variable([1. , 1. , 0.5482849, 0.6564884, 0.8446237], dtype=float32) >>> view + 10 - DeviceArray([11., 11.], dtype=float32) + Array([11., 11.], dtype=float32) >>> view *= 10 VariableView([10., 10.], dtype=float32) diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 9608ec951..79f4e78b4 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -14,8 +14,7 @@ from jax.tree_util import register_pytree_node from brainpy.check import jit_error_checking -from ._utils import _return -from .ndarray import Array, Variable +from .ndarray import Array, Variable, _return __all__ = [ 'RandomState', 'Generator', 'DEFAULT', diff --git a/brainpy/math/others.py b/brainpy/math/others.py index 93e2f1a46..5e72756ef 100644 --- a/brainpy/math/others.py +++ b/brainpy/math/others.py @@ -3,6 +3,6 @@ from brainpy._src.math.others import ( shared_args_over_time as shared_args_over_time, ) -from brainpy._src.math._utils import ( +from brainpy._src.math.ndarray import ( npfun_returns_bparray as npfun_returns_bparray ) From f86fc31a896edbed0ce120164cdb6bbab7455360 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sat, 21 Jan 2023 16:37:43 +0800 Subject: [PATCH 2/8] nonbatch_shape --- brainpy/_src/dyn/base.py | 2 +- brainpy/_src/math/ndarray.py | 12 +++++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py index cd79a38f3..0f38324d5 100644 --- a/brainpy/_src/dyn/base.py +++ b/brainpy/_src/dyn/base.py @@ -1402,7 +1402,7 @@ def __init__( raise UnsupportedError('Should provide varshape when the target does ' f'not define its {SLICE_VARS}') all_vars = target.vars(level=1, include_self=True, method='relative') - all_vars = {k: v for k, v in all_vars.items() if v.shape_nb == varshape} + all_vars = {k: v for k, v in all_vars.items() if v.nobatch_shape == varshape} else: all_vars = {} for var_str in getattr(self.target, SLICE_VARS): diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index f715f843a..56b53e023 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -1044,12 +1044,14 @@ def __init__( f'but the batch axis is set to be {batch_axis}.') @property - def shape_nb(self) -> TupleType[int, ...]: + def nobatch_shape(self) -> TupleType[int, ...]: """Shape without batch axis.""" - shape = list(self.value.shape) if self.batch_axis is not None: + shape = list(self.value.shape) shape.pop(self.batch_axis) - return tuple(shape) + return tuple(shape) + else: + return self.shape @property def batch_axis(self) -> Optional[int]: @@ -1400,8 +1402,8 @@ def dot(self, b): """Dot product of two arrays.""" return self.value.dot(b.value if isinstance(b, Array) else b) - def flatten(self, order='C'): - return self.value.flatten(order=order) + def flatten(self): + return self.value.flatten() def item(self, *args): """Copy an element of an array to a standard Python scalar and return it.""" From 2cf6fdc89abf10cad70ccabb103b044952987d7d Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 11:28:42 +0800 Subject: [PATCH 3/8] change analysis access --- brainpy/_src/analysis/constants.py | 51 ++++++++++++++++++++++-------- brainpy/_src/analysis/plotstyle.py | 1 + brainpy/analysis.py | 9 ++---- 3 files changed, 41 insertions(+), 20 deletions(-) diff --git a/brainpy/_src/analysis/constants.py b/brainpy/_src/analysis/constants.py index e9691cca5..af9bd4397 100644 --- a/brainpy/_src/analysis/constants.py +++ b/brainpy/_src/analysis/constants.py @@ -4,9 +4,47 @@ __all__ = [ 'CONTINUOUS', 'DISCRETE', + + 'F_vmap_fx', + 'F_vmap_fy', + 'F_vmap_brentq_fx', + 'F_vmap_brentq_fy', + 'F_vmap_fp_aux', + 'F_vmap_fp_opt', + 'F_vmap_dfxdx', + 'F_fx', + 'F_fy', + 'F_fz', + 'F_dfxdx', + 'F_dfxdy', + 'F_dfydx', + 'F_dfydy', + 'F_jacobian', + 'F_vmap_jacobian', + 'F_fixed_point_aux', + 'F_fixed_point_opt', + 'F_x_by_y', + 'F_y_by_x', + 'F_y_convert', + 'F_x_convert', + 'F_int_x', + 'F_int_y', + 'x_by_y', + 'y_by_x', + 'y_by_x_in_fy', + 'y_by_x_in_fx', + 'x_by_y_in_fx', + 'x_by_y_in_fy', + 'F_y_by_x_in_fy', + 'F_x_by_y_in_fy', + 'F_y_by_x_in_fx', + 'F_x_by_y_in_fx', + 'fx_nullcline_points', + 'fy_nullcline_points', ] + CONTINUOUS = 'continuous' DISCRETE = 'discrete' @@ -26,13 +64,8 @@ F_dfydy = 'F_dfydy' F_jacobian = 'F_jacobian' F_vmap_jacobian = 'F_vmap_jacobian' -F_fixed_point = 'F_fixed_point' F_fixed_point_aux = 'F_fixed_point_aux' F_fixed_point_opt = 'F_fixed_point_opt' -F_fx_nullcline_by_opt = 'F_fx_nullcline_by_opt' -F_fy_nullcline_by_opt = 'F_fy_nullcline_by_opt' -F_x_in_all = 'F_x_in_all' -F_y_in_all = 'F_y_in_all' F_x_by_y = 'F_x_by_y' F_y_by_x = 'F_y_by_x' F_y_convert = 'F_y_convert' @@ -52,13 +85,5 @@ F_x_by_y_in_fx = 'F[fx::x=f(y)]' fx_nullcline_points = 'fx_nullcline_points' fy_nullcline_points = 'fy_nullcline_points' -sympy_failed = 'sympy_failed' -sympy_success = 'sympy_success' -sympy_escape = 'sympy_escape' -sympy_timeout = 'sympy_timeout' -fx_sign = 'fx_sign' -fy_sign = 'fy_sign' -par_eval_parallel = 'par_eval_parallel' -par_eval_iter = 'par_eval_iter' prefix = '\t' diff --git a/brainpy/_src/analysis/plotstyle.py b/brainpy/_src/analysis/plotstyle.py index 3a81735c1..e59ea6435 100644 --- a/brainpy/_src/analysis/plotstyle.py +++ b/brainpy/_src/analysis/plotstyle.py @@ -4,6 +4,7 @@ __all__ = [ 'plot_schema', 'set_plot_schema', + 'set_markersize', ] from .stability import (CENTER_MANIFOLD, SADDLE_NODE, STABLE_POINT_1D, diff --git a/brainpy/analysis.py b/brainpy/analysis.py index d00719c2b..1105f73e3 100644 --- a/brainpy/analysis.py +++ b/brainpy/analysis.py @@ -17,11 +17,6 @@ SlowPointFinder as SlowPointFinder, ) -from brainpy._src.analysis.plotstyle import ( - set_plot_schema as set_plot_schema, -) +from brainpy._src.analysis import plotstyle, stability, constants +C = constants -from brainpy._src.analysis.constants import ( - CONTINUOUS as CONTINUOUS, - DISCRETE as DISCRETE, -) From b8b083ebe127a45ff34724532c8570879cb9d60e Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 11:44:25 +0800 Subject: [PATCH 4/8] `brainpy.dyn` module compatible --- brainpy/__init__.py | 359 +--------------------------- brainpy/_src/dyn/synapses/compat.py | 50 +++- brainpy/dyn.py | 4 + 3 files changed, 52 insertions(+), 361 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index a058532e5..521c038a7 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -207,6 +207,7 @@ dyn.__dict__['LoopOverTime'] = LoopOverTime dyn.__dict__['DSRunner'] = DSRunner +# neurons dyn.__dict__['HH'] = neurons.HH dyn.__dict__['MorrisLecar'] = neurons.MorrisLecar dyn.__dict__['PinskyRinzelModel'] = neurons.PinskyRinzelModel @@ -225,6 +226,7 @@ dyn.__dict__['PoissonGroup'] = neurons.PoissonGroup dyn.__dict__['OUProcess'] = neurons.OUProcess +# synapses from brainpy._src.dyn.synapses import compat dyn.__dict__['DeltaSynapse'] = compat.DeltaSynapse dyn.__dict__['ExpCUBA'] = compat.ExpCUBA @@ -236,360 +238,3 @@ dyn.__dict__['NMDA'] = compat.NMDA del compat - -# import brainpy._src.math.arraycompatible as bm -# math.__dict__['full'] = bm.full -# math.__dict__['full_like'] = bm.full_like -# math.__dict__['eye'] = bm.eye -# math.__dict__['identity'] = bm.identity -# math.__dict__['diag'] = bm.diag -# math.__dict__['tri'] = bm.tri -# math.__dict__['tril'] = bm.tril -# math.__dict__['triu'] = bm.triu -# math.__dict__['real'] = bm.real -# math.__dict__['imag'] = bm.imag -# math.__dict__['conj'] = bm.conj -# math.__dict__['conjugate'] = bm.conjugate -# math.__dict__['ndim'] = bm.ndim -# math.__dict__['isreal'] = bm.isreal -# math.__dict__['isscalar'] = bm.isscalar -# math.__dict__['add'] = bm.add -# math.__dict__['reciprocal'] = bm.reciprocal -# math.__dict__['negative'] = bm.negative -# math.__dict__['positive'] = bm.positive -# math.__dict__['multiply'] = bm.multiply -# math.__dict__['divide'] = bm.divide -# math.__dict__['power'] = bm.power -# math.__dict__['subtract'] = bm.subtract -# math.__dict__['true_divide'] = bm.true_divide -# math.__dict__['floor_divide'] = bm.floor_divide -# math.__dict__['float_power'] = bm.float_power -# math.__dict__['fmod'] = bm.fmod -# math.__dict__['mod'] = bm.mod -# math.__dict__['modf'] = bm.modf -# math.__dict__['divmod'] = bm.divmod -# math.__dict__['remainder'] = bm.remainder -# math.__dict__['abs'] = bm.abs -# math.__dict__['exp'] = bm.exp -# math.__dict__['exp2'] = bm.exp2 -# math.__dict__['expm1'] = bm.expm1 -# math.__dict__['log'] = bm.log -# math.__dict__['log10'] = bm.log10 -# math.__dict__['log1p'] = bm.log1p -# math.__dict__['log2'] = bm.log2 -# math.__dict__['logaddexp'] = bm.logaddexp -# math.__dict__['logaddexp2'] = bm.logaddexp2 -# math.__dict__['lcm'] = bm.lcm -# math.__dict__['gcd'] = bm.gcd -# math.__dict__['arccos'] = bm.arccos -# math.__dict__['arccosh'] = bm.arccosh -# math.__dict__['arcsin'] = bm.arcsin -# math.__dict__['arcsinh'] = bm.arcsinh -# math.__dict__['arctan'] = bm.arctan -# math.__dict__['arctan2'] = bm.arctan2 -# math.__dict__['arctanh'] = bm.arctanh -# math.__dict__['cos'] = bm.cos -# math.__dict__['cosh'] = bm.cosh -# math.__dict__['sin'] = bm.sin -# math.__dict__['sinc'] = bm.sinc -# math.__dict__['sinh'] = bm.sinh -# math.__dict__['tan'] = bm.tan -# math.__dict__['tanh'] = bm.tanh -# math.__dict__['deg2rad'] = bm.deg2rad -# math.__dict__['hypot'] = bm.hypot -# math.__dict__['rad2deg'] = bm.rad2deg -# math.__dict__['degrees'] = bm.degrees -# math.__dict__['radians'] = bm.radians -# math.__dict__['round'] = bm.round -# math.__dict__['around'] = bm.around -# math.__dict__['round_'] = bm.round_ -# math.__dict__['rint'] = bm.rint -# math.__dict__['floor'] = bm.floor -# math.__dict__['ceil'] = bm.ceil -# math.__dict__['trunc'] = bm.trunc -# math.__dict__['fix'] = bm.fix -# math.__dict__['prod'] = bm.prod -# math.__dict__['sum'] = bm.sum -# math.__dict__['diff'] = bm.diff -# math.__dict__['median'] = bm.median -# math.__dict__['nancumprod'] = bm.nancumprod -# math.__dict__['nancumsum'] = bm.nancumsum -# math.__dict__['nanprod'] = bm.nanprod -# math.__dict__['nansum'] = bm.nansum -# math.__dict__['cumprod'] = bm.cumprod -# math.__dict__['cumsum'] = bm.cumsum -# math.__dict__['ediff1d'] = bm.ediff1d -# math.__dict__['cross'] = bm.cross -# math.__dict__['trapz'] = bm.trapz -# math.__dict__['isfinite'] = bm.isfinite -# math.__dict__['isinf'] = bm.isinf -# math.__dict__['isnan'] = bm.isnan -# math.__dict__['signbit'] = bm.signbit -# math.__dict__['copysign'] = bm.copysign -# math.__dict__['nextafter'] = bm.nextafter -# math.__dict__['ldexp'] = bm.ldexp -# math.__dict__['frexp'] = bm.frexp -# math.__dict__['convolve'] = bm.convolve -# math.__dict__['sqrt'] = bm.sqrt -# math.__dict__['cbrt'] = bm.cbrt -# math.__dict__['square'] = bm.square -# math.__dict__['absolute'] = bm.absolute -# math.__dict__['fabs'] = bm.fabs -# math.__dict__['sign'] = bm.sign -# math.__dict__['heaviside'] = bm.heaviside -# math.__dict__['maximum'] = bm.maximum -# math.__dict__['minimum'] = bm.minimum -# math.__dict__['fmax'] = bm.fmax -# math.__dict__['fmin'] = bm.fmin -# math.__dict__['interp'] = bm.interp -# math.__dict__['clip'] = bm.clip -# math.__dict__['angle'] = bm.angle -# math.__dict__['bitwise_and'] = bm.bitwise_and -# math.__dict__['bitwise_not'] = bm.bitwise_not -# math.__dict__['bitwise_or'] = bm.bitwise_or -# math.__dict__['bitwise_xor'] = bm.bitwise_xor -# math.__dict__['invert'] = bm.invert -# math.__dict__['left_shift'] = bm.left_shift -# math.__dict__['right_shift'] = bm.right_shift -# math.__dict__['equal'] = bm.equal -# math.__dict__['not_equal'] = bm.not_equal -# math.__dict__['greater'] = bm.greater -# math.__dict__['greater_equal'] = bm.greater_equal -# math.__dict__['less'] = bm.less -# math.__dict__['less_equal'] = bm.less_equal -# math.__dict__['array_equal'] = bm.array_equal -# math.__dict__['isclose'] = bm.isclose -# math.__dict__['allclose'] = bm.allclose -# math.__dict__['logical_and'] = bm.logical_and -# math.__dict__['logical_not'] = bm.logical_not -# math.__dict__['logical_or'] = bm.logical_or -# math.__dict__['logical_xor'] = bm.logical_xor -# math.__dict__['all'] = bm.all -# math.__dict__['any'] = bm.any -# math.__dict__['alltrue'] = bm.alltrue -# math.__dict__['sometrue'] = bm.sometrue -# math.__dict__['shape'] = bm.shape -# math.__dict__['size'] = bm.size -# math.__dict__['reshape'] = bm.reshape -# math.__dict__['ravel'] = bm.ravel -# math.__dict__['moveaxis'] = bm.moveaxis -# math.__dict__['transpose'] = bm.transpose -# math.__dict__['swapaxes'] = bm.swapaxes -# math.__dict__['concatenate'] = bm.concatenate -# math.__dict__['stack'] = bm.stack -# math.__dict__['vstack'] = bm.vstack -# math.__dict__['hstack'] = bm.hstack -# math.__dict__['dstack'] = bm.dstack -# math.__dict__['column_stack'] = bm.column_stack -# math.__dict__['split'] = bm.split -# math.__dict__['dsplit'] = bm.dsplit -# math.__dict__['hsplit'] = bm.hsplit -# math.__dict__['vsplit'] = bm.vsplit -# math.__dict__['tile'] = bm.tile -# math.__dict__['repeat'] = bm.repeat -# math.__dict__['unique'] = bm.unique -# math.__dict__['append'] = bm.append -# math.__dict__['flip'] = bm.flip -# math.__dict__['fliplr'] = bm.fliplr -# math.__dict__['flipud'] = bm.flipud -# math.__dict__['roll'] = bm.roll -# math.__dict__['atleast_1d'] = bm.atleast_1d -# math.__dict__['atleast_2d'] = bm.atleast_2d -# math.__dict__['atleast_3d'] = bm.atleast_3d -# math.__dict__['expand_dims'] = bm.expand_dims -# math.__dict__['squeeze'] = bm.squeeze -# math.__dict__['sort'] = bm.sort -# math.__dict__['argsort'] = bm.argsort -# math.__dict__['argmax'] = bm.argmax -# math.__dict__['argmin'] = bm.argmin -# math.__dict__['argwhere'] = bm.argwhere -# math.__dict__['nonzero'] = bm.nonzero -# math.__dict__['flatnonzero'] = bm.flatnonzero -# math.__dict__['where'] = bm.where -# math.__dict__['searchsorted'] = bm.searchsorted -# math.__dict__['extract'] = bm.extract -# math.__dict__['count_nonzero'] = bm.count_nonzero -# math.__dict__['max'] = bm.max -# math.__dict__['min'] = bm.min -# math.__dict__['amax'] = bm.amax -# math.__dict__['amin'] = bm.amin -# math.__dict__['array_split'] = bm.array_split -# math.__dict__['meshgrid'] = bm.meshgrid -# math.__dict__['vander'] = bm.vander -# math.__dict__['nonzero'] = bm.nonzero -# math.__dict__['where'] = bm.where -# math.__dict__['tril_indices'] = bm.tril_indices -# math.__dict__['tril_indices_from'] = bm.tril_indices_from -# math.__dict__['triu_indices'] = bm.triu_indices -# math.__dict__['triu_indices_from'] = bm.triu_indices_from -# math.__dict__['take'] = bm.take -# math.__dict__['select'] = bm.select -# math.__dict__['nanmin'] = bm.nanmin -# math.__dict__['nanmax'] = bm.nanmax -# math.__dict__['ptp'] = bm.ptp -# math.__dict__['percentile'] = bm.percentile -# math.__dict__['nanpercentile'] = bm.nanpercentile -# math.__dict__['quantile'] = bm.quantile -# math.__dict__['nanquantile'] = bm.nanquantile -# math.__dict__['median'] = bm.median -# math.__dict__['average'] = bm.average -# math.__dict__['mean'] = bm.mean -# math.__dict__['std'] = bm.std -# math.__dict__['var'] = bm.var -# math.__dict__['nanmedian'] = bm.nanmedian -# math.__dict__['nanmean'] = bm.nanmean -# math.__dict__['nanstd'] = bm.nanstd -# math.__dict__['nanvar'] = bm.nanvar -# math.__dict__['corrcoef'] = bm.corrcoef -# math.__dict__['correlate'] = bm.correlate -# math.__dict__['cov'] = bm.cov -# math.__dict__['histogram'] = bm.histogram -# math.__dict__['bincount'] = bm.bincount -# math.__dict__['digitize'] = bm.digitize -# math.__dict__['bartlett'] = bm.bartlett -# math.__dict__['blackman'] = bm.blackman -# math.__dict__['hamming'] = bm.hamming -# math.__dict__['hanning'] = bm.hanning -# math.__dict__['kaiser'] = bm.kaiser -# math.__dict__['e'] = bm.e -# math.__dict__['pi'] = bm.pi -# math.__dict__['inf'] = bm.inf -# math.__dict__['dot'] = bm.dot -# math.__dict__['vdot'] = bm.vdot -# math.__dict__['inner'] = bm.inner -# math.__dict__['outer'] = bm.outer -# math.__dict__['kron'] = bm.kron -# math.__dict__['matmul'] = bm.matmul -# math.__dict__['trace'] = bm.trace -# math.__dict__['dtype'] = bm.dtype -# math.__dict__['finfo'] = bm.finfo -# math.__dict__['iinfo'] = bm.iinfo -# math.__dict__['uint8'] = bm.uint8 -# math.__dict__['uint16'] = bm.uint16 -# math.__dict__['uint32'] = bm.uint32 -# math.__dict__['uint64'] = bm.uint64 -# math.__dict__['int8'] = bm.int8 -# math.__dict__['int16'] = bm.int16 -# math.__dict__['int32'] = bm.int32 -# math.__dict__['int64'] = bm.int64 -# math.__dict__['float16'] = bm.float16 -# math.__dict__['float32'] = bm.float32 -# math.__dict__['float64'] = bm.float64 -# math.__dict__['complex64'] = bm.complex64 -# math.__dict__['complex128'] = bm.complex128 -# math.__dict__['product'] = bm.product -# math.__dict__['row_stack'] = bm.row_stack -# math.__dict__['apply_over_axes'] = bm.apply_over_axes -# math.__dict__['apply_along_axis'] = bm.apply_along_axis -# math.__dict__['array_equiv'] = bm.array_equiv -# math.__dict__['array_repr'] = bm.array_repr -# math.__dict__['array_str'] = bm.array_str -# math.__dict__['block'] = bm.block -# math.__dict__['broadcast_arrays'] = bm.broadcast_arrays -# math.__dict__['broadcast_shapes'] = bm.broadcast_shapes -# math.__dict__['broadcast_to'] = bm.broadcast_to -# math.__dict__['compress'] = bm.compress -# math.__dict__['cumproduct'] = bm.cumproduct -# math.__dict__['diag_indices'] = bm.diag_indices -# math.__dict__['diag_indices_from'] = bm.diag_indices_from -# math.__dict__['diagflat'] = bm.diagflat -# math.__dict__['diagonal'] = bm.diagonal -# math.__dict__['einsum'] = bm.einsum -# math.__dict__['einsum_path'] = bm.einsum_path -# math.__dict__['geomspace'] = bm.geomspace -# math.__dict__['gradient'] = bm.gradient -# math.__dict__['histogram2d'] = bm.histogram2d -# math.__dict__['histogram_bin_edges'] = bm.histogram_bin_edges -# math.__dict__['histogramdd'] = bm.histogramdd -# math.__dict__['i0'] = bm.i0 -# math.__dict__['in1d'] = bm.in1d -# math.__dict__['indices'] = bm.indices -# math.__dict__['insert'] = bm.insert -# math.__dict__['intersect1d'] = bm.intersect1d -# math.__dict__['iscomplex'] = bm.iscomplex -# math.__dict__['isin'] = bm.isin -# math.__dict__['ix_'] = bm.ix_ -# math.__dict__['lexsort'] = bm.lexsort -# math.__dict__['load'] = bm.load -# math.__dict__['save'] = bm.save -# math.__dict__['savez'] = bm.savez -# math.__dict__['mask_indices'] = bm.mask_indices -# math.__dict__['msort'] = bm.msort -# math.__dict__['nan_to_num'] = bm.nan_to_num -# math.__dict__['nanargmax'] = bm.nanargmax -# math.__dict__['setdiff1d'] = bm.setdiff1d -# math.__dict__['nanargmin'] = bm.nanargmin -# math.__dict__['pad'] = bm.pad -# math.__dict__['poly'] = bm.poly -# math.__dict__['polyadd'] = bm.polyadd -# math.__dict__['polyder'] = bm.polyder -# math.__dict__['polyfit'] = bm.polyfit -# math.__dict__['polyint'] = bm.polyint -# math.__dict__['polymul'] = bm.polymul -# math.__dict__['polysub'] = bm.polysub -# math.__dict__['polyval'] = bm.polyval -# math.__dict__['resize'] = bm.resize -# math.__dict__['rollaxis'] = bm.rollaxis -# math.__dict__['roots'] = bm.roots -# math.__dict__['rot90'] = bm.rot90 -# math.__dict__['setxor1d'] = bm.setxor1d -# math.__dict__['tensordot'] = bm.tensordot -# math.__dict__['trim_zeros'] = bm.trim_zeros -# math.__dict__['union1d'] = bm.union1d -# math.__dict__['unravel_index'] = bm.unravel_index -# math.__dict__['unwrap'] = bm.unwrap -# math.__dict__['take_along_axis'] = bm.take_along_axis -# math.__dict__['can_cast'] = bm.can_cast -# math.__dict__['choose'] = bm.choose -# math.__dict__['copy'] = bm.copy -# math.__dict__['frombuffer'] = bm.frombuffer -# math.__dict__['fromfile'] = bm.fromfile -# math.__dict__['fromfunction'] = bm.fromfunction -# math.__dict__['fromiter'] = bm.fromiter -# math.__dict__['fromstring'] = bm.fromstring -# math.__dict__['get_printoptions'] = bm.get_printoptions -# math.__dict__['iscomplexobj'] = bm.iscomplexobj -# math.__dict__['isneginf'] = bm.isneginf -# math.__dict__['isposinf'] = bm.isposinf -# math.__dict__['isrealobj'] = bm.isrealobj -# math.__dict__['issubdtype'] = bm.issubdtype -# math.__dict__['issubsctype'] = bm.issubsctype -# math.__dict__['iterable'] = bm.iterable -# math.__dict__['packbits'] = bm.packbits -# math.__dict__['piecewise'] = bm.piecewise -# math.__dict__['printoptions'] = bm.printoptions -# math.__dict__['set_printoptions'] = bm.set_printoptions -# math.__dict__['promote_types'] = bm.promote_types -# math.__dict__['ravel_multi_index'] = bm.ravel_multi_index -# math.__dict__['result_type'] = bm.result_type -# math.__dict__['sort_complex'] = bm.sort_complex -# math.__dict__['unpackbits'] = bm.unpackbits -# math.__dict__['delete'] = bm.delete -# math.__dict__['add_docstring'] = bm.add_docstring -# math.__dict__['add_newdoc'] = bm.add_newdoc -# math.__dict__['add_newdoc_ufunc'] = bm.add_newdoc_ufunc -# math.__dict__['array2string'] = bm.array2string -# math.__dict__['asanyarray'] = bm.asanyarray -# math.__dict__['ascontiguousarray'] = bm.ascontiguousarray -# math.__dict__['asfarray'] = bm.asfarray -# math.__dict__['asscalar'] = bm.asscalar -# math.__dict__['common_type'] = bm.common_type -# math.__dict__['disp'] = bm.disp -# math.__dict__['genfromtxt'] = bm.genfromtxt -# math.__dict__['loadtxt'] = bm.loadtxt -# math.__dict__['info'] = bm.info -# math.__dict__['issubclass_'] = bm.issubclass_ -# math.__dict__['place'] = bm.place -# math.__dict__['polydiv'] = bm.polydiv -# math.__dict__['put'] = bm.put -# math.__dict__['putmask'] = bm.putmask -# math.__dict__['safe_eval'] = bm.safe_eval -# math.__dict__['savetxt'] = bm.savetxt -# math.__dict__['savez_compressed'] = bm.savez_compressed -# math.__dict__['show_config'] = bm.show_config -# math.__dict__['typename'] = bm.typename -# math.__dict__['copyto'] = bm.copyto -# math.__dict__['matrix'] = bm.matrix -# math.__dict__['asmatrix'] = bm.asmatrix -# math.__dict__['mat'] = bm.mat -# del bm diff --git a/brainpy/_src/dyn/synapses/compat.py b/brainpy/_src/dyn/synapses/compat.py index 1fb343ad5..e6d9f56df 100644 --- a/brainpy/_src/dyn/synapses/compat.py +++ b/brainpy/_src/dyn/synapses/compat.py @@ -1,14 +1,15 @@ # -*- coding: utf-8 -*- import warnings -from typing import Union, Dict, Callable +from typing import Union, Dict, Callable, Optional +import brainpy._src.math as bm from brainpy._src.connect import TwoEndConnector -from brainpy._src.dyn.base import NeuGroup -from brainpy._src.dyn.synouts import COBA, CUBA +from brainpy._src.dyn.base import NeuGroup, SynSTP +from brainpy._src.dyn.synouts import COBA, CUBA, MgBlock from brainpy._src.initialize import Initializer from brainpy.types import ArrayType -from .abstract_models import Delta, Exponential, DualExponential, NMDA +from .abstract_models import Delta, Exponential, DualExponential, NMDA as NewNMDA __all__ = [ 'DeltaSynapse', @@ -256,3 +257,44 @@ def __init__( tau_rise=tau_decay, method=method, name=name) + + +class NMDA(NewNMDA): + def __init__( + self, + pre: NeuGroup, + post: NeuGroup, + conn: Union[TwoEndConnector, ArrayType, Dict[str, ArrayType]], + E=0., + alpha=0.062, + beta=3.57, + cc_Mg=1.2, + stp: Optional[SynSTP] = None, + comp_method: str = 'dense', + g_max: Union[float, ArrayType, Initializer, Callable] = 0.15, + delay_step: Union[int, ArrayType, Initializer, Callable] = None, + tau_decay: Union[float, ArrayType] = 100., + a: Union[float, ArrayType] = 0.5, + tau_rise: Union[float, ArrayType] = 2., + method: str = 'exp_auto', + + # other parameters + name: str = None, + mode: bm.Mode = None, + stop_spike_gradient: bool = False, + ): + super(NMDA, self).__init__(pre=pre, + post=post, + conn=conn, + output=MgBlock(E=E, alpha=alpha, beta=beta, cc_Mg=cc_Mg), + stp=stp, + name=name, + mode=mode, + comp_method=comp_method, + g_max=g_max, + delay_step=delay_step, + tau_decay=tau_decay, + a=a, + tau_rise=tau_rise, + method=method, + stop_spike_gradient=stop_spike_gradient) diff --git a/brainpy/dyn.py b/brainpy/dyn.py index 1142c59c6..407631559 100644 --- a/brainpy/dyn.py +++ b/brainpy/dyn.py @@ -4,3 +4,7 @@ Deprecated. Use ``brainpy.xxx`` instead. """ + + + + From d2b2a157fe0524d03368aa6e26810710c4fd8682 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 14:36:10 +0800 Subject: [PATCH 5/8] monitors and inputs compatibility --- brainpy/__init__.py | 2 +- brainpy/_src/dyn/runners.py | 3 +++ brainpy/_src/math/ndarray.py | 6 ++---- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/brainpy/__init__.py b/brainpy/__init__.py index 521c038a7..ad6fbe015 100644 --- a/brainpy/__init__.py +++ b/brainpy/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -__version__ = "2.3.2" +__version__ = "2.3.3" # fundamental supporting modules diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py index 4a567de22..267209f9c 100644 --- a/brainpy/_src/dyn/runners.py +++ b/brainpy/_src/dyn/runners.py @@ -361,6 +361,7 @@ def __init__( warnings.warn('`fun_inputs` is deprecated since version 2.3.1. ' 'Define `fun_inputs` as `inputs` instead.', UserWarning) + self._fun_inputs = fun_inputs if callable(inputs): self._inputs = inputs else: @@ -551,6 +552,8 @@ def _step_func_monitor(self, shared): return res def _step_func_input(self, shared): + if self._fun_inputs is not None: + self._fun_inputs(shared) if callable(self._inputs): self._inputs(shared) else: diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 56b53e023..9ceccbd74 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -93,10 +93,8 @@ def _check_input_array(array): def _return(a): - if _return_bp_array: - if isinstance(a, jax.Array): - if a.ndim > 1: - return Array(a) + if _return_bp_array and isinstance(a, jax.Array) and a.ndim > 1: + return Array(a) return a From 237a4967b273cffe54fd7c532a67836885d35899 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 14:44:26 +0800 Subject: [PATCH 6/8] fix bug --- brainpy/_src/initialize/tests/test_regular_inits.py | 1 - brainpy/_src/math/ndarray.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/brainpy/_src/initialize/tests/test_regular_inits.py b/brainpy/_src/initialize/tests/test_regular_inits.py index 85e6bfe91..8316fc661 100644 --- a/brainpy/_src/initialize/tests/test_regular_inits.py +++ b/brainpy/_src/initialize/tests/test_regular_inits.py @@ -20,7 +20,6 @@ def test_one_init(self): init = bp.init.OneInit(value=value) weights = init(size) assert weights.shape == size - assert isinstance(weights, bp.math.ndarray) assert (weights == value).all() diff --git a/brainpy/_src/math/ndarray.py b/brainpy/_src/math/ndarray.py index 9ceccbd74..451abc9c6 100644 --- a/brainpy/_src/math/ndarray.py +++ b/brainpy/_src/math/ndarray.py @@ -93,7 +93,7 @@ def _check_input_array(array): def _return(a): - if _return_bp_array and isinstance(a, jax.Array) and a.ndim > 1: + if _return_bp_array and isinstance(a, jax.Array) and a.ndim > 0: return Array(a) return a From 49b0dc8b2e1027823219247a3c9248303394ad4a Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 15:03:12 +0800 Subject: [PATCH 7/8] fix bugs --- brainpy/_src/analysis/constants.py | 3 --- brainpy/_src/math/random.py | 4 ++-- brainpy/analysis.py | 3 +++ 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/brainpy/_src/analysis/constants.py b/brainpy/_src/analysis/constants.py index af9bd4397..16898bac4 100644 --- a/brainpy/_src/analysis/constants.py +++ b/brainpy/_src/analysis/constants.py @@ -2,9 +2,6 @@ __all__ = [ - 'CONTINUOUS', - 'DISCRETE', - 'F_vmap_fx', 'F_vmap_fy', 'F_vmap_brentq_fx', diff --git a/brainpy/_src/math/random.py b/brainpy/_src/math/random.py index 79f4e78b4..2de4526ca 100644 --- a/brainpy/_src/math/random.py +++ b/brainpy/_src/math/random.py @@ -973,8 +973,8 @@ def wald(self, mean, scale, size=None, key=None): if size is None: size = lax.broadcast_shapes(jnp.shape(mean), jnp.shape(scale)) size = _size2shape(size) - sampled_chi2 = jnp.square(self.randn(*size)) - sampled_uniform = self.uniform(size=size, key=key) + sampled_chi2 = jnp.square(_as_jax_array(self.randn(*size))) + sampled_uniform = _as_jax_array(self.uniform(size=size, key=key)) # Wikipedia defines an intermediate x with the formula # x = loc + loc ** 2 * y / (2 * conc) - loc / (2 * conc) * sqrt(4 * loc * conc * y + loc ** 2 * y ** 2) # where y ~ N(0, 1)**2 (sampled_chi2 above) and conc is the concentration. diff --git a/brainpy/analysis.py b/brainpy/analysis.py index 1105f73e3..34f14f645 100644 --- a/brainpy/analysis.py +++ b/brainpy/analysis.py @@ -17,6 +17,9 @@ SlowPointFinder as SlowPointFinder, ) +from brainpy._src.analysis.constants import (CONTINUOUS as CONTINUOUS, + DISCRETE as DISCRETE) + from brainpy._src.analysis import plotstyle, stability, constants C = constants From 130c2d3dd2e76f097da1eb56d4c177eb94d46d74 Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 25 Jan 2023 15:16:51 +0800 Subject: [PATCH 8/8] updates --- brainpy/_src/math/operators/pre_syn_post.py | 1 - 1 file changed, 1 deletion(-) diff --git a/brainpy/_src/math/operators/pre_syn_post.py b/brainpy/_src/math/operators/pre_syn_post.py index 2592b6926..c2aacca85 100644 --- a/brainpy/_src/math/operators/pre_syn_post.py +++ b/brainpy/_src/math/operators/pre_syn_post.py @@ -90,7 +90,6 @@ def pre2post_event_sum(events, out: ArrayType A tensor with the shape of ``post_num``. """ - warnings.warn('Please use ``brainpylib.event_ops.event_csr_matvec()`` instead.', UserWarning) indices, idnptr = pre2post events = as_jax(events) indices = as_jax(indices)