From 68bdb2f6fff4bf5ae7b61a8144d29eec3768e9cb Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 29 Aug 2023 10:09:11 +0800 Subject: [PATCH 1/2] [encoding] upgrade encoding methods --- brainpy/_src/encoding/base.py | 9 +- brainpy/_src/encoding/stateful_encoding.py | 166 +++++++++------ brainpy/_src/encoding/stateless_encoding.py | 201 ++++++++++++++---- .../encoding/tests/test_stateless_encoding.py | 79 +++++++ brainpy/encoding.py | 3 +- 5 files changed, 350 insertions(+), 108 deletions(-) create mode 100644 brainpy/_src/encoding/tests/test_stateless_encoding.py diff --git a/brainpy/_src/encoding/base.py b/brainpy/_src/encoding/base.py index c85a0b98c..d2a53242d 100644 --- a/brainpy/_src/encoding/base.py +++ b/brainpy/_src/encoding/base.py @@ -10,8 +10,13 @@ class Encoder(BrainPyObject): """Base class for encoding rate values as spike trains.""" - def __call__(self, *args, **kwargs): - raise NotImplementedError def __repr__(self): return self.__class__.__name__ + + def single_step(self, *args, **kwargs): + raise NotImplementedError('Please implement the function for single step encoding.') + + def multi_steps(self, *args, **kwargs): + raise NotImplementedError('Encode implement the function for multiple-step encoding.') + diff --git a/brainpy/_src/encoding/stateful_encoding.py b/brainpy/_src/encoding/stateful_encoding.py index b40e4f427..c2b6ced2e 100644 --- a/brainpy/_src/encoding/stateful_encoding.py +++ b/brainpy/_src/encoding/stateful_encoding.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- import math -from typing import Union, Callable +from typing import Union, Callable, Optional import jax +import numpy as np import brainpy.math as bm from brainpy import check @@ -47,13 +48,10 @@ def __init__(self, weight_fun: Callable = None): super().__init__() - check.is_integer(num_phase, 'num_phase', min_bound=1) - check.is_float(min_val, 'min_val') - check.is_float(max_val, 'max_val') check.is_callable(weight_fun, 'weight_fun', allow_none=True) - self.num_phase = num_phase - self.min_val = min_val - self.max_val = max_val + self.num_phase = check.is_integer(num_phase, 'num_phase', min_bound=1) + self.min_val = check.is_float(min_val, 'min_val') + self.max_val = check.is_float(max_val, 'max_val') self.weight_fun = (lambda i: 2 ** (-(i % num_phase + 1))) if weight_fun is None else weight_fun self.scale = (1 - self.weight_fun(self.num_phase - 1)) / (self.max_val - self.min_val) @@ -88,74 +86,112 @@ def f(i): class LatencyEncoder(Encoder): - r"""Encode the rate input as the spike train. + r"""Encode the rate input as the spike train using the latency encoding. - The latency encoder will encode ``x`` (normalized into ``[0, 1]`` according to + Use input features to determine time-to-first spike. + + Expected inputs should be between 0 and 1. If not, the latency encoder will encode ``x`` + (normalized into ``[0, 1]`` according to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`) to spikes whose firing time is :math:`0 \le t_f \le \text{num_period}-1`. A larger ``x`` will cause the earlier firing time. - Parameters - ---------- - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - num_period: int - The periodic firing time step. - method: str - How to convert intensity to firing time. Currently, we support `linear` or `log`. - - - If ``method='linear'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. - - If ``method='log'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, - where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. + + Example:: + + >>> a = bm.array([0.02, 0.5, 1]) + >>> encoder = LatencyEncoder(method='linear', normalize=True) + >>> encoder.multi_steps(a, n_time=5) + Array([[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.]]) + + + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`. + - If ``method='linear'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. + - If ``method='log'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, + where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. + threshold: float. Input features below the threhold will fire at the + final time step unless ``clip=True`` in which case they will not + fire at all, defaults to ``0.01``. + clip: bool. Option to remove spikes from features that fall + below the threshold, defaults to ``False``. + tau: float. RC Time constant for LIF model used to calculate + firing time, defaults to ``1``. + normalize: bool. Option to normalize the latency code such that + the final spike(s) occur within num_steps, defaults to ``False``. + epsilon: float. A tiny positive value to avoid rounding errors when + using torch.arange, defaults to ``1e-7``. """ - def __init__(self, - min_val: float, - max_val: float, - num_period: int, - method: str = 'linear'): + def __init__( + self, + min_val: float = None, + max_val: float = None, + method: str = 'log', + threshold: float = 0.01, + clip: bool = False, + tau: float = 1., + normalize: bool = False, + first_spk_time: float = 0., + epsilon: float = 1e-7, + ): super().__init__() - check.is_integer(num_period, 'num_period', min_bound=1) - check.is_float(min_val, 'min_val') - check.is_float(max_val, 'max_val') - assert method in ['linear', 'log'] - self.num_period = num_period - self.min_val = min_val - self.max_val = max_val + if method not in ['linear', 'log']: + raise ValueError('The conversion method can only be "linear" and "log".') self.method = method - - def __call__(self, x: ArrayType, i_step: Union[int, ArrayType]): - """Encoding function. - - Parameters - ---------- - x: ArrayType - The input rate value. - i_step: int, ArrayType - The indices of the time step. - - Returns - ------- - out: ArrayType - The encoded spike train. + self.min_val = check.is_float(min_val, 'min_val', allow_none=True) + self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + if threshold < 0 or threshold > 1: + raise ValueError(f"``threshold`` [{threshold}] must be between [0, 1]") + self.threshold = threshold + self.clip = clip + self.tau = tau + self.normalize = normalize + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(first_spk_time / bm.get_dt()) + self.epsilon = epsilon + + def single_step(self, x, i_step: int = None): + raise NotImplementedError + + def multi_steps(self, data, n_time: Optional[float] = None): + """Generate latency spikes according to the given input data. + + Ensuring x in [0., 1.]. + + Args: + data: The rate-based input. + n_time: float. The total time to generate data. If None, use ``tau`` instead. + + Returns: + out: array. The output spiking trains. """ - _temp = self.num_period - 1. - if self.method == 'log': - alpha = math.exp(_temp) - 1. - t_f = bm.round(_temp - bm.log(alpha * x + 1.)).astype(bm.int_) - else: - t_f = bm.round(_temp * (1. - x)).astype(bm.int_) + if n_time is None: + n_time = self.tau + tau = n_time if self.normalize else self.tau + x = data + if self.min_val is not None and self.max_val is not None: + x = (x - self.min_val) / (self.max_val - self.min_val) + if self.method == 'linear': + spike_time = (tau - self.first_spk_time - bm.dt) * (1 - x) + self.first_spk_time + + elif self.method == 'log': + x = bm.maximum(x, self.threshold + self.epsilon) # saturates all values below threshold. + spike_time = (tau - self.first_spk_time - bm.dt) * bm.log(x / (x - self.threshold)) + self.first_spk_time - def f(i): - return bm.as_jax(t_f == (i % self.num_period), dtype=x.dtype) - - if isinstance(i_step, int): - return f(i_step) else: - assert isinstance(i_step, (jax.Array, bm.Array)) - return jax.vmap(f, i_step) + raise ValueError(f'Unsupported method: {self.method}. Only support "log" and "linear".') + + if self.clip: + spike_time = bm.where(data < self.threshold, np.inf, spike_time) + spike_steps = bm.round(spike_time / bm.get_dt()).astype(int) + return bm.one_hot(spike_steps, num_classes=int(n_time / bm.get_dt()), axis=0, dtype=x.dtype) diff --git a/brainpy/_src/encoding/stateless_encoding.py b/brainpy/_src/encoding/stateless_encoding.py index 700a6330c..5410d736c 100644 --- a/brainpy/_src/encoding/stateless_encoding.py +++ b/brainpy/_src/encoding/stateless_encoding.py @@ -1,68 +1,189 @@ # -*- coding: utf-8 -*- -from typing import Union, Optional +from typing import Optional -import jax import brainpy.math as bm from brainpy import check -from brainpy.types import ArrayType from .base import Encoder __all__ = [ 'PoissonEncoder', + 'DiffEncoder', ] class PoissonEncoder(Encoder): r"""Encode the rate input as the Poisson spike train. - Given the input :math:`x`, the poisson encoder will output - spikes whose firing probability is :math:`x_{\text{normalize}}`, where - :math:`x_{\text{normalize}}` is normalized into ``[0, 1]`` according + Expected inputs should be between 0 and 1. If not, the input :math:`x` will be + normalized to :math:`x_{\text{normalize}}` within ``[0, 1]`` according to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`. - Parameters - ---------- - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - seed: int, ArrayType - The seed or key for random generation. + Given the input :math:`x`, the poisson encoder will output + spikes whose firing probability is :math:`x_{\text{normalize}}`. + + + Examples:: + + import brainpy as bp + import brainpy.math as bm + + img = bm.random.random((10, 2)) # image to encode (normalized to [0., 1.]) + encoder = bp.encoding.PoissonEncoder() # the encoder + + # encode the image at each time + for run_index in range(100): + spike = encoder.single_step(img) + # do something + + # or, encode the image at multiple times once + spikes = encoder.multi_steps(img, n_time=10.) + + + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + gain: float. Scale input features by the gain, defaults to ``1``. + offset: float. Shift input features by the offset, defaults to ``0``. + first_spk_time: float. The time to first spike, defaults to ``0``. """ - def __init__(self, - min_val: Optional[float] = None, - max_val: Optional[float] = None): + def __init__( + self, + min_val: Optional[float] = None, + max_val: Optional[float] = None, + gain: float = 1.0, + offset: float = 0.0, + first_spk_time: float = 0., + ): super().__init__() self.min_val = check.is_float(min_val, 'min_val', allow_none=True) self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + self.gain = check.is_float(gain, allow_none=False) + self.offset = check.is_float(offset, allow_none=False) + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(self.first_spk_time / bm.get_dt()) + + def single_step(self, x, i_step: int = None): + """Generate spikes at the single step according to the inputs. + + Args: + x: Array. The rate input. + i_step: int. The time step to generate spikes. - def __call__(self, x: ArrayType, num_step: int = None): + Returns: + out: Array. The encoded spike train. """ + if i_step is None: + return self.multi_steps(x, n_time=None) + else: + return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x) + + def multi_steps(self, x, n_time: Optional[float]): + """Generate spikes at multiple steps according to the inputs. + + Args: + x: Array. The rate input. + n_time: float. Encode rate values as spike trains in the given time length. + ``n_time`` is converted into the ``n_step`` according to `n_step = int(n_time / brainpy.math.dt)`. + - If ``n_time=None``, encode the rate values at the current time step. + Users should repeatedly call it to encode `x` as a spike train. + - Else, given the ``x`` with shape ``(S, ...)``, the encoded + spike train is the array with shape ``(n_step, S, ...)``. - Parameters - ---------- - x: ArrayType - The rate input. - num_step: int - Encode rate values as spike trains in the given time length. - - - If ``time_len=None``, encode the rate values at the current time step. - Users should repeatedly call it to encode `x` as a spike train. - - Else, given the ``x`` with shape ``(S, ...)``, the encoded - spike train is the array with shape ``(time_len, S, ...)``. - - Returns - ------- - out: ArrayType - The encoded spike train. + Returns: + out: Array. The encoded spike train. """ - with jax.ensure_compile_time_eval(): - check.is_integer(num_step, 'time_len', min_bound=1, allow_none=True) - if not (self.min_val is None or self.max_val is None): + n_time = int(n_time / bm.get_dt()) + + if (self.min_val is not None) and (self.max_val is not None): x = (x - self.min_val) / (self.max_val - self.min_val) - shape = x.shape if (num_step is None) else ((num_step,) + x.shape) - d = bm.as_jax(bm.random.rand(*shape)) < x - return d.astype(x.dtype) + x = x * self.gain + self.offset + if n_time is not None and self.first_spk_step > 0: + pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype) + shape = ((n_time - self.first_spk_step,) + x.shape) + post = bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + return bm.cat([pre, post], axis=0) + else: + shape = x.shape if (n_time is None) else ((n_time - self.first_spk_step,) + x.shape) + return bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + + def _zero_out(self, x): + return bm.zeros_like(x) + + +class DiffEncoder(Encoder): + """Generate spike only when the difference between two subsequent + time steps meets a threshold. + + Optionally include `off_spikes` for negative changes. + + Example:: + + >>> a = bm.array([1, 2, 2.9, 3, 3.9]) + >>> encoder = DiffEncoder(threshold=1) + >>> encoder.multi_steps(a) + Array([1., 0., 0., 0.]) + + >>> encoder = DiffEncoder(threshold=1, padding=True) + >>> encoder.multi_steps(a) + Array([0., 1., 0., 0., 0.]) + + >>> b = bm.array([1, 2, 0, 2, 2.9]) + >>> encoder = DiffEncoder(threshold=1, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 1., 1., -1., 1., 0.]) + + >>> encoder = DiffEncoder(threshold=1, padding=True, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 0., 1., -1., 1., 0.]) + + Args: + threshold: float. Input features with a change greater than the thresold + across one timestep will generate a spike, defaults to ``0.1``. + padding: bool. Used to change how the first time step of spikes are + measured. If ``True``, the first time step will be repeated with itself + resulting in ``0``'s for the output spikes. + If ``False``, the first time step will be padded with ``0``'s, defaults + to ``False``. + off_spike: bool. If ``True``, negative spikes for changes less than + ``-threshold``, defaults to ``False``. + """ + + def __init__( + self, + threshold: float = 0.1, + padding: bool = False, + off_spike: bool = False, + ): + super().__init__() + + self.threshold = threshold + self.padding = padding + self.off_spike = off_spike + + def single_step(self, *args, **kwargs): + raise NotImplementedError(f'{DiffEncoder.__class__.__name__} does not support single-step encoding.') + + def multi_steps(self, x): + """Encoding multistep inputs with the spiking trains. + + Args: + x: Array. The array with the shape of `(num_step, ....)`. + + Returns: + out: Array. The spike train. + """ + if self.padding: + diff = bm.diff(x, axis=0, prepend=x[:1]) + else: + diff = bm.diff(x, axis=0, prepend=bm.zeros((1,) + x.shape[1:], dtype=x.dtype)) + + if self.off_spike: + on_spk = bm.asarray(diff >= self.threshold, dtype=x.dtype) + off_spk = -bm.asarray(diff <= -self.threshold, dtype=x.dtype) + return on_spk + off_spk + + else: + return bm.asarray(diff >= self.threshold, dtype=x.dtype) diff --git a/brainpy/_src/encoding/tests/test_stateless_encoding.py b/brainpy/_src/encoding/tests/test_stateless_encoding.py new file mode 100644 index 000000000..3fec2a964 --- /dev/null +++ b/brainpy/_src/encoding/tests/test_stateless_encoding.py @@ -0,0 +1,79 @@ +import unittest +import brainpy.math as bm +import brainpy as bp + + +class TestDiffEncoder(unittest.TestCase): + def test_delta(self): + a = bm.array([1, 2, 2.9, 3, 3.9]) + encoder = bp.encoding.DiffEncoder(threshold=1) + r = encoder.multi_steps(a) + excepted = bm.asarray([1., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True) + r = encoder.multi_steps(a) + excepted = bm.asarray([0., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + bm.clear_buffer_memory() + + def test_delta_off_spike(self): + b = bm.array([1, 2, 0, 2, 2.9]) + encoder = bp.encoding.DiffEncoder(threshold=1, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([1., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([0., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + bm.clear_buffer_memory() + + +class TestLatencyEncoder(unittest.TestCase): + def test_latency(self): + a = bm.array([0.02, 0.5, 1]) + encoder = bp.encoding.LatencyEncoder(method='linear') + + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + + r = encoder.multi_steps(a, n_time=1.0) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.LatencyEncoder(method='linear', normalize=True) + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + diff --git a/brainpy/encoding.py b/brainpy/encoding.py index 4a2de0be7..b51f9d744 100644 --- a/brainpy/encoding.py +++ b/brainpy/encoding.py @@ -9,6 +9,7 @@ WeightedPhaseEncoder as WeightedPhaseEncoder, ) from brainpy._src.encoding.stateless_encoding import ( - PoissonEncoder as PoissonEncoder + PoissonEncoder as PoissonEncoder, + DiffEncoder as DiffEncoder, ) From fe90ddeff8f43bf1d2233e9ba55a7e2208a4ae3d Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 29 Aug 2023 10:09:20 +0800 Subject: [PATCH 2/2] common updates --- brainpy/_add_deprecations.py | 12 --- brainpy/_src/dyn/others/common.py | 5 +- brainpy/_src/math/compat_pytorch.py | 76 +++++++++------ brainpy/_src/visualization/animation.py | 121 ++++++++++++++++++++++++ brainpy/_src/visualization/base.py | 5 + brainpy/math/compat_pytorch.py | 2 + 6 files changed, 180 insertions(+), 41 deletions(-) create mode 100644 brainpy/_src/visualization/animation.py diff --git a/brainpy/_add_deprecations.py b/brainpy/_add_deprecations.py index 89fd1dd8c..741728ef4 100644 --- a/brainpy/_add_deprecations.py +++ b/brainpy/_add_deprecations.py @@ -102,18 +102,6 @@ dyn.__getattr__ = deprecation_getattr2('brainpy.dyn', dyn.__deprecations) -# dnn.__deprecations = { -# 'Layer': ('brainpy.dnn.Layer', 'brainpy.AnnLayer', AnnLayer), -# } -# dnn.__getattr__ = deprecation_getattr2('brainpy.dnn', dnn.__deprecations) - - -# layers.__deprecations = { -# 'Layer': ('brainpy.layers.Layer', 'brainpy.AnnLayer', AnnLayer), -# } -# layers.__getattr__ = deprecation_getattr2('brainpy.layers', layers.__deprecations) - - connect.__deprecations = { 'one2one': ('brainpy.connect.one2one', 'brainpy.connect.One2One', connect.One2One), 'all2all': ('brainpy.connect.all2all', 'brainpy.connect.All2All', connect.All2All), diff --git a/brainpy/_src/dyn/others/common.py b/brainpy/_src/dyn/others/common.py index ef069d4ea..b5be6b23a 100644 --- a/brainpy/_src/dyn/others/common.py +++ b/brainpy/_src/dyn/others/common.py @@ -76,8 +76,9 @@ def update(self, inp=None): t = share.load('t') dt = share.load('dt') self.x.value = self.integral(self.x.value, t, dt) - if inp is not None: - self.x += inp + if inp is None: inp = 0. + inp = self.sum_inputs(self.x.value, init=inp) + self.x += inp return self.x.value def return_info(self): diff --git a/brainpy/_src/math/compat_pytorch.py b/brainpy/_src/math/compat_pytorch.py index 419f2d146..86695e440 100644 --- a/brainpy/_src/math/compat_pytorch.py +++ b/brainpy/_src/math/compat_pytorch.py @@ -6,7 +6,7 @@ from .ndarray import Array, _as_jax_array_, _return, _check_out from .compat_numpy import ( - concatenate, shape + concatenate, shape, minimum, maximum, ) __all__ = [ @@ -31,9 +31,10 @@ 'arctan', 'atan2', 'atanh', + 'clamp_max', + 'clamp_min', ] - Tensor = Array cat = concatenate @@ -80,28 +81,28 @@ def flatten(input: Union[jax.Array, Array], raise ValueError(f'start_dim {start_dim} is out of size.') if end_dim < 0 or end_dim > ndim: raise ValueError(f'end_dim {end_dim} is out of size.') - new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:] + new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int),) + shape[end_dim:] return jnp.reshape(input, new_shape) def unsqueeze(input: Union[jax.Array, Array], dim: int) -> Array: - """Returns a new tensor with a dimension of size one inserted at the specified position. - The returned tensor shares the same underlying data with this tensor. - A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. - Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1. - Parameters - ---------- - input: Array - The input Array - dim: int - The index at which to insert the singleton dimension - - Returns - ------- - out: Array - """ - input = _as_jax_array_(input) - return Array(jnp.expand_dims(input, dim)) + """Returns a new tensor with a dimension of size one inserted at the specified position. +The returned tensor shares the same underlying data with this tensor. +A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. +Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1. +Parameters +---------- +input: Array + The input Array +dim: int + The index at which to insert the singleton dimension + +Returns +------- +out: Array +""" + input = _as_jax_array_(input) + return Array(jnp.expand_dims(input, dim)) # Math operations @@ -115,10 +116,12 @@ def abs(input: Union[jax.Array, Array], _check_out(out) out.value = r + absolute = abs + def acos(input: Union[jax.Array, Array], - *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.arccos(input) if out is None: @@ -127,10 +130,12 @@ def acos(input: Union[jax.Array, Array], _check_out(out) out.value = r + arccos = acos + def acosh(input: Union[jax.Array, Array], - *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.arccosh(input) if out is None: @@ -139,8 +144,10 @@ def acosh(input: Union[jax.Array, Array], _check_out(out) out.value = r + arccosh = acosh + def add(input: Union[jax.Array, Array, jnp.number], other: Union[jax.Array, Array, jnp.number], *, alpha: Optional[jnp.number] = 1, @@ -155,6 +162,7 @@ def add(input: Union[jax.Array, Array, jnp.number], _check_out(out) out.value = r + def addcdiv(input: Union[jax.Array, Array, jnp.number], tensor1: Union[jax.Array, Array, jnp.number], tensor2: Union[jax.Array, Array, jnp.number], @@ -165,7 +173,8 @@ def addcdiv(input: Union[jax.Array, Array, jnp.number], other = jnp.divide(tensor1, tensor2) return add(input, other, alpha=value, out=out) -def addcmul(input: Union[jax.Array, Array, jnp.number], + +def addcmul(input: Union[jax.Array, Array, jnp.number], tensor1: Union[jax.Array, Array, jnp.number], tensor2: Union[jax.Array, Array, jnp.number], *, value: jnp.number = 1, @@ -175,6 +184,7 @@ def addcmul(input: Union[jax.Array, Array, jnp.number], other = jnp.multiply(tensor1, tensor2) return add(input, other, alpha=value, out=out) + def angle(input: Union[jax.Array, Array, jnp.number], *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) @@ -185,8 +195,9 @@ def angle(input: Union[jax.Array, Array, jnp.number], _check_out(out) out.value = r + def asin(input: Union[jax.Array, Array], - *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.arcsin(input) if out is None: @@ -195,10 +206,12 @@ def asin(input: Union[jax.Array, Array], _check_out(out) out.value = r + arcsin = asin + def asinh(input: Union[jax.Array, Array], - *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.arcsinh(input) if out is None: @@ -207,10 +220,12 @@ def asinh(input: Union[jax.Array, Array], _check_out(out) out.value = r + arcsinh = asinh + def atan(input: Union[jax.Array, Array], - *, out: Optional[Union[Array,jax.Array, np.ndarray]] = None) -> Optional[Array]: + *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) r = jnp.arctan(input) if out is None: @@ -219,8 +234,10 @@ def atan(input: Union[jax.Array, Array], _check_out(out) out.value = r + arctan = atan + def atanh(input: Union[jax.Array, Array], *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) @@ -231,8 +248,10 @@ def atanh(input: Union[jax.Array, Array], _check_out(out) out.value = r + arctanh = atanh + def atan2(input: Union[jax.Array, Array], *, out: Optional[Union[Array, jax.Array, np.ndarray]] = None) -> Optional[Array]: input = _as_jax_array_(input) @@ -243,4 +262,7 @@ def atan2(input: Union[jax.Array, Array], _check_out(out) out.value = r -arctan2 = atan2 \ No newline at end of file + +arctan2 = atan2 +clamp_max = minimum +clamp_min = maximum diff --git a/brainpy/_src/visualization/animation.py b/brainpy/_src/visualization/animation.py new file mode 100644 index 000000000..6848799c1 --- /dev/null +++ b/brainpy/_src/visualization/animation.py @@ -0,0 +1,121 @@ +from collections import defaultdict +from typing import Dict, List + +import matplotlib.pyplot as plt +from matplotlib.animation import ArtistAnimation +from matplotlib.artist import Artist +from matplotlib.figure import Figure + +import brainpy.math as bm + +__all__ = [ + 'animator', +] + + +def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"): + """Generate an animation by looping through the first dimension of a + sample of spiking data. + Time must be the first dimension of ``data``. + + Example:: + + import matplotlib.pyplot as plt + + # Index into a single sample from a minibatch + spike_data_sample = bm.random.rand(100, 28, 28) + print(spike_data_sample.shape) + >>> (100, 28, 28) + + # Plot + fig, ax = plt.subplots() + anim = splt.animator(spike_data_sample, fig, ax) + HTML(anim.to_html5_video()) + + # Save as a gif + anim.save("spike_mnist.gif") + + :param data: Data tensor for a single sample across time steps of + shape [num_steps x input_size] + :type data: torch.Tensor + + :param fig: Top level container for all plot elements + :type fig: matplotlib.figure.Figure + + :param ax: Contains additional figure elements and sets the coordinate + system. E.g.: + fig, ax = plt.subplots(facecolor='w', figsize=(12, 7)) + :type ax: matplotlib.axes._subplots.AxesSubplot + + :param num_steps: Number of time steps to plot. If not specified, + the number of entries in the first dimension + of ``data`` will automatically be used, defaults to ``False`` + :type num_steps: int, optional + + :param interval: Delay between frames in milliseconds, defaults to ``40`` + :type interval: int, optional + + :param cmap: color map, defaults to ``plasma`` + :type cmap: string, optional + + :return: animation to be displayed using ``matplotlib.pyplot.show()`` + :rtype: FuncAnimation + + """ + + data = bm.as_numpy(data) + if not num_steps: + num_steps = data.shape[0] + camera = Camera(fig) + plt.axis("off") + # iterate over time and take a snapshot with celluloid + for step in range( + num_steps + ): # im appears unused but is required by camera.snap() + im = ax.imshow(data[step], cmap=cmap) # noqa: F841 + camera.snap() + anim = camera.animate(interval=interval) + return anim + + +class Camera: + """Make animations easier.""" + + def __init__(self, figure: Figure) -> None: + """Create camera from matplotlib figure.""" + self._figure = figure + # need to keep track off artists for each axis + self._offsets: Dict[str, Dict[int, int]] = { + k: defaultdict(int) + for k in [ + "collections", + "patches", + "lines", + "texts", + "artists", + "images", + ] + } + self._photos: List[List[Artist]] = [] + + def snap(self) -> List[Artist]: + """Capture current state of the figure.""" + frame_artists: List[Artist] = [] + for i, axis in enumerate(self._figure.axes): + if axis.legend_ is not None: + axis.add_artist(axis.legend_) + for name in self._offsets: + new_artists = getattr(axis, name)[self._offsets[name][i]:] + frame_artists += new_artists + self._offsets[name][i] += len(new_artists) + self._photos.append(frame_artists) + return frame_artists + + def animate(self, *args, **kwargs) -> ArtistAnimation: + """Animate the snapshots taken. + Uses matplotlib.animation.ArtistAnimation + Returns + ------- + ArtistAnimation + """ + return ArtistAnimation(self._figure, self._photos, *args, **kwargs) diff --git a/brainpy/_src/visualization/base.py b/brainpy/_src/visualization/base.py index 36a67ea7c..efd33cdc8 100644 --- a/brainpy/_src/visualization/base.py +++ b/brainpy/_src/visualization/base.py @@ -105,3 +105,8 @@ def plot_style1(fontsize=22, lw=1): from .styles import plot_style1 plot_style1(fontsize=fontsize, axes_edgecolor=axes_edgecolor, figsize=figsize, lw=lw) + + @staticmethod + def animator(data, fig, ax, num_steps=False, interval=40, cmap="plasma"): + from .animation import animator + return animator(data, fig, ax, num_steps=num_steps, interval=interval, cmap=cmap) diff --git a/brainpy/math/compat_pytorch.py b/brainpy/math/compat_pytorch.py index 919134aac..f522b6ab7 100644 --- a/brainpy/math/compat_pytorch.py +++ b/brainpy/math/compat_pytorch.py @@ -23,4 +23,6 @@ arctan as arctan, atan2 as atan2, atanh as atanh, + clamp_max, + clamp_min, )