Skip to content

Commit

Permalink
common updates
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Aug 29, 2023
1 parent 68bdb2f commit fe90dde
Show file tree
Hide file tree
Showing 6 changed files with 180 additions and 41 deletions.
12 changes: 0 additions & 12 deletions brainpy/_add_deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,6 @@
dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations)


# dnn.__deprecations = {
# 'Layer': ('brainpy.dnn.Layer', 'brainpy.AnnLayer', AnnLayer),
# }
# dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations)


# layers.__deprecations = {
# 'Layer': ('brainpy.layers.Layer', 'brainpy.AnnLayer', AnnLayer),
# }
# layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations)


connect.__deprecations = {
'one2one': ('brainpy.connect.one2one', 'brainpy.connect.One2One', connect.One2One),
'all2all': ('brainpy.connect.all2all', 'brainpy.connect.All2All', connect.All2All),
Expand Down
5 changes: 3 additions & 2 deletions brainpy/_src/dyn/others/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def update(self, inp=None):
t = share.load('t')
dt = share.load('dt')
self.x.value = self.integral(self.x.value, t, dt)
if inp is not None:
self.x += inp
if inp is None: inp = 0.
inp = self.sum_inputs(self.x.value, init=inp)
self.x += inp
return self.x.value

def return_info(self):
Expand Down
76 changes: 49 additions & 27 deletions brainpy/_src/math/compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from .ndarray import Array, _as_jax_array_, _return, _check_out
from .compat_numpy import (
concatenate, shape
concatenate, shape, minimum, maximum,
)

__all__ = [
Expand All @@ -31,9 +31,10 @@
'arctan',
'atan2',
'atanh',
'clamp_max',
'clamp_min',
]


Tensor = Array
cat = concatenate

Expand Down Expand Up @@ -80,28 +81,28 @@ def flatten(input: Union[jax.Array, Array],
raise ValueError(f'start_dim {start_dim} is out of size.')
if end_dim < 0 or end_dim > ndim:
raise ValueError(f'end_dim {end_dim} is out of size.')
new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:]
new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:]
return jnp.reshape(input, new_shape)


def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array:
"""Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with this tensor.
A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used.
Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
Parameters
----------
input: Array
The input Array
dim: int
The index at which to insert the singleton dimension
Returns
-------
out: Array
"""
input = _as_jax_array_(input)
return Array(jnp.expand_dims(input, dim))
"""Returns a new tensor with a dimension of size one inserted at the specified position.
The returned tensor shares the same underlying data with this tensor.
A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used.
Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
Parameters
----------
input: Array
The input Array
dim: int
The index at which to insert the singleton dimension
Returns
-------
out: Array
"""
input = _as_jax_array_(input)
return Array(jnp.expand_dims(input, dim))


# Math operations
Expand All @@ -115,10 +116,12 @@ def abs(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


absolute = abs


def acos(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arccos(input)
if out is None:
Expand All @@ -127,10 +130,12 @@ def acos(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arccos = acos


def acosh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arccosh(input)
if out is None:
Expand All @@ -139,8 +144,10 @@ def acosh(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arccosh = acosh


def add(input: Union[jax.Array, Array, jnp.number],
other: Union[jax.Array, Array, jnp.number],
*, alpha: Optional[jnp.number] = 1,
Expand All @@ -155,6 +162,7 @@ def add(input: Union[jax.Array, Array, jnp.number],
_check_out(out)
out.value = r


def addcdiv(input: Union[jax.Array, Array, jnp.number],
tensor1: Union[jax.Array, Array, jnp.number],
tensor2: Union[jax.Array, Array, jnp.number],
Expand All @@ -165,7 +173,8 @@ def addcdiv(input: Union[jax.Array, Array, jnp.number],
other = jnp.divide(tensor1, tensor2)
return add(input, other, alpha=value, out=out)

def addcmul(input: Union[jax.Array, Array, jnp.number],

def addcmul(input: Union[jax.Array, Array, jnp.number],
tensor1: Union[jax.Array, Array, jnp.number],
tensor2: Union[jax.Array, Array, jnp.number],
*, value: jnp.number = 1,
Expand All @@ -175,6 +184,7 @@ def addcmul(input: Union[jax.Array, Array, jnp.number],
other = jnp.multiply(tensor1, tensor2)
return add(input, other, alpha=value, out=out)


def angle(input: Union[jax.Array, Array, jnp.number],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
Expand All @@ -185,8 +195,9 @@ def angle(input: Union[jax.Array, Array, jnp.number],
_check_out(out)
out.value = r


def asin(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arcsin(input)
if out is None:
Expand All @@ -195,10 +206,12 @@ def asin(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arcsin = asin


def asinh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arcsinh(input)
if out is None:
Expand All @@ -207,10 +220,12 @@ def asinh(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arcsinh = asinh


def atan(input: Union[jax.Array, Array],
*, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]:
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
r = jnp.arctan(input)
if out is None:
Expand All @@ -219,8 +234,10 @@ def atan(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arctan = atan


def atanh(input: Union[jax.Array, Array],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
Expand All @@ -231,8 +248,10 @@ def atanh(input: Union[jax.Array, Array],
_check_out(out)
out.value = r


arctanh = atanh


def atan2(input: Union[jax.Array, Array],
*, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]:
input = _as_jax_array_(input)
Expand All @@ -243,4 +262,7 @@ def atan2(input: Union[jax.Array, Array],
_check_out(out)
out.value = r

arctan2 = atan2

arctan2 = atan2
clamp_max = minimum
clamp_min = maximum
121 changes: 121 additions & 0 deletions brainpy/_src/visualization/animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from collections import defaultdict
from typing import Dict, List

import matplotlib.pyplot as plt
from matplotlib.animation import ArtistAnimation
from matplotlib.artist import Artist
from matplotlib.figure import Figure

import brainpy.math as bm

__all__ = [
'animator',
]


def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"):
"""Generate an animation by looping through the first dimension of a
sample of spiking data.
Time must be the first dimension of ``data``.
Example::
import matplotlib.pyplot as plt
# Index into a single sample from a minibatch
spike_data_sample = bm.random.rand(100, 28, 28)
print(spike_data_sample.shape)
>>> (100, 28, 28)
# Plot
fig, ax = plt.subplots()
anim = splt.animator(spike_data_sample, fig, ax)
HTML(anim.to_html5_video())
# Save as a gif
anim.save("spike_mnist.gif")
:param data: Data tensor for a single sample across time steps of
shape [num_steps x input_size]
:type data: torch.Tensor
:param fig: Top level container for all plot elements
:type fig: matplotlib.figure.Figure
:param ax: Contains additional figure elements and sets the coordinate
system. E.g.:
fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))
:type ax: matplotlib.axes._subplots.AxesSubplot
:param num_steps: Number of time steps to plot. If not specified,
the number of entries in the first dimension
of ``data`` will automatically be used, defaults to ``False``
:type num_steps: int, optional
:param interval: Delay between frames in milliseconds, defaults to ``40``
:type interval: int, optional
:param cmap: color map, defaults to ``plasma``
:type cmap: string, optional
:return: animation to be displayed using ``matplotlib.pyplot.show()``
:rtype: FuncAnimation
"""

data = bm.as_numpy(data)
if not num_steps:
num_steps = data.shape[0]
camera = Camera(fig)
plt.axis("off")
# iterate over time and take a snapshot with celluloid
for step in range(
num_steps
): # im appears unused but is required by camera.snap()
im = ax.imshow(data[step], cmap=cmap) # noqa: F841
camera.snap()
anim = camera.animate(interval=interval)
return anim


class Camera:
"""Make animations easier."""

def __init__(self, figure: Figure) -> None:
"""Create camera from matplotlib figure."""
self._figure = figure
# need to keep track off artists for each axis
self._offsets: Dict[str, Dict[int, int]] = {
k: defaultdict(int)
for k in [
"collections",
"patches",
"lines",
"texts",
"artists",
"images",
]
}
self._photos: List[List[Artist]] = []

def snap(self) -> List[Artist]:
"""Capture current state of the figure."""
frame_artists: List[Artist] = []
for i, axis in enumerate(self._figure.axes):
if axis.legend_ is not None:
axis.add_artist(axis.legend_)
for name in self._offsets:
new_artists = getattr(axis, name)[self._offsets[name][i]:]
frame_artists += new_artists
self._offsets[name][i] += len(new_artists)
self._photos.append(frame_artists)
return frame_artists

def animate(self, *args, **kwargs) -> ArtistAnimation:
"""Animate the snapshots taken.
Uses matplotlib.animation.ArtistAnimation
Returns
-------
ArtistAnimation
"""
return ArtistAnimation(self._figure, self._photos, *args, **kwargs)
5 changes: 5 additions & 0 deletions brainpy/_src/visualization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,8 @@ def plot_style1(fontsize=22,
lw=1):
from .styles import plot_style1
plot_style1(fontsize=fontsize, axes_edgecolor=axes_edgecolor, figsize=figsize, lw=lw)

@staticmethod
def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"):
from .animation import animator
return animator(data, fig, ax, num_steps=num_steps, interval=interval, cmap=cmap)
2 changes: 2 additions & 0 deletions brainpy/math/compat_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,6 @@
arctan as arctan,
atan2 as atan2,
atanh as atanh,
clamp_max,
clamp_min,
)

0 comments on commit fe90dde

Please sign in to comment.