diff --git a/brainpy/_src/encoding/base.py b/brainpy/_src/encoding/base.py index c85a0b98c..d2a53242d 100644 --- a/brainpy/_src/encoding/base.py +++ b/brainpy/_src/encoding/base.py @@ -10,8 +10,13 @@ class Encoder(BrainPyObject): """Base class for encoding rate values as spike trains.""" - def __call__(self, *args, **kwargs): - raise NotImplementedError def __repr__(self): return self.__class__.__name__ + + def single_step(self, *args, **kwargs): + raise NotImplementedError('Please implement the function for single step encoding.') + + def multi_steps(self, *args, **kwargs): + raise NotImplementedError('Encode implement the function for multiple-step encoding.') + diff --git a/brainpy/_src/encoding/stateful_encoding.py b/brainpy/_src/encoding/stateful_encoding.py index b40e4f427..c2b6ced2e 100644 --- a/brainpy/_src/encoding/stateful_encoding.py +++ b/brainpy/_src/encoding/stateful_encoding.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- import math -from typing import Union, Callable +from typing import Union, Callable, Optional import jax +import numpy as np import brainpy.math as bm from brainpy import check @@ -47,13 +48,10 @@ def __init__(self, weight_fun: Callable = None): super().__init__() - check.is_integer(num_phase, 'num_phase', min_bound=1) - check.is_float(min_val, 'min_val') - check.is_float(max_val, 'max_val') check.is_callable(weight_fun, 'weight_fun', allow_none=True) - self.num_phase = num_phase - self.min_val = min_val - self.max_val = max_val + self.num_phase = check.is_integer(num_phase, 'num_phase', min_bound=1) + self.min_val = check.is_float(min_val, 'min_val') + self.max_val = check.is_float(max_val, 'max_val') self.weight_fun = (lambda i: 2 ** (-(i % num_phase + 1))) if weight_fun is None else weight_fun self.scale = (1 - self.weight_fun(self.num_phase - 1)) / (self.max_val - self.min_val) @@ -88,74 +86,112 @@ def f(i): class LatencyEncoder(Encoder): - r"""Encode the rate input as the spike train. + r"""Encode the rate input as the spike train using the latency encoding. - The latency encoder will encode ``x`` (normalized into ``[0, 1]`` according to + Use input features to determine time-to-first spike. + + Expected inputs should be between 0 and 1. If not, the latency encoder will encode ``x`` + (normalized into ``[0, 1]`` according to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`) to spikes whose firing time is :math:`0 \le t_f \le \text{num_period}-1`. A larger ``x`` will cause the earlier firing time. - Parameters - ---------- - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - num_period: int - The periodic firing time step. - method: str - How to convert intensity to firing time. Currently, we support `linear` or `log`. - - - If ``method='linear'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. - - If ``method='log'``, the firing rate is calculated as - :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, - where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. + + Example:: + + >>> a = bm.array([0.02, 0.5, 1]) + >>> encoder = LatencyEncoder(method='linear', normalize=True) + >>> encoder.multi_steps(a, n_time=5) + Array([[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.]]) + + + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`. + - If ``method='linear'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`. + - If ``method='log'``, the firing rate is calculated as + :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`, + where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`. + threshold: float. Input features below the threhold will fire at the + final time step unless ``clip=True`` in which case they will not + fire at all, defaults to ``0.01``. + clip: bool. Option to remove spikes from features that fall + below the threshold, defaults to ``False``. + tau: float. RC Time constant for LIF model used to calculate + firing time, defaults to ``1``. + normalize: bool. Option to normalize the latency code such that + the final spike(s) occur within num_steps, defaults to ``False``. + epsilon: float. A tiny positive value to avoid rounding errors when + using torch.arange, defaults to ``1e-7``. """ - def __init__(self, - min_val: float, - max_val: float, - num_period: int, - method: str = 'linear'): + def __init__( + self, + min_val: float = None, + max_val: float = None, + method: str = 'log', + threshold: float = 0.01, + clip: bool = False, + tau: float = 1., + normalize: bool = False, + first_spk_time: float = 0., + epsilon: float = 1e-7, + ): super().__init__() - check.is_integer(num_period, 'num_period', min_bound=1) - check.is_float(min_val, 'min_val') - check.is_float(max_val, 'max_val') - assert method in ['linear', 'log'] - self.num_period = num_period - self.min_val = min_val - self.max_val = max_val + if method not in ['linear', 'log']: + raise ValueError('The conversion method can only be "linear" and "log".') self.method = method - - def __call__(self, x: ArrayType, i_step: Union[int, ArrayType]): - """Encoding function. - - Parameters - ---------- - x: ArrayType - The input rate value. - i_step: int, ArrayType - The indices of the time step. - - Returns - ------- - out: ArrayType - The encoded spike train. + self.min_val = check.is_float(min_val, 'min_val', allow_none=True) + self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + if threshold < 0 or threshold > 1: + raise ValueError(f"``threshold`` [{threshold}] must be between [0, 1]") + self.threshold = threshold + self.clip = clip + self.tau = tau + self.normalize = normalize + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(first_spk_time / bm.get_dt()) + self.epsilon = epsilon + + def single_step(self, x, i_step: int = None): + raise NotImplementedError + + def multi_steps(self, data, n_time: Optional[float] = None): + """Generate latency spikes according to the given input data. + + Ensuring x in [0., 1.]. + + Args: + data: The rate-based input. + n_time: float. The total time to generate data. If None, use ``tau`` instead. + + Returns: + out: array. The output spiking trains. """ - _temp = self.num_period - 1. - if self.method == 'log': - alpha = math.exp(_temp) - 1. - t_f = bm.round(_temp - bm.log(alpha * x + 1.)).astype(bm.int_) - else: - t_f = bm.round(_temp * (1. - x)).astype(bm.int_) + if n_time is None: + n_time = self.tau + tau = n_time if self.normalize else self.tau + x = data + if self.min_val is not None and self.max_val is not None: + x = (x - self.min_val) / (self.max_val - self.min_val) + if self.method == 'linear': + spike_time = (tau - self.first_spk_time - bm.dt) * (1 - x) + self.first_spk_time + + elif self.method == 'log': + x = bm.maximum(x, self.threshold + self.epsilon) # saturates all values below threshold. + spike_time = (tau - self.first_spk_time - bm.dt) * bm.log(x / (x - self.threshold)) + self.first_spk_time - def f(i): - return bm.as_jax(t_f == (i % self.num_period), dtype=x.dtype) - - if isinstance(i_step, int): - return f(i_step) else: - assert isinstance(i_step, (jax.Array, bm.Array)) - return jax.vmap(f, i_step) + raise ValueError(f'Unsupported method: {self.method}. Only support "log" and "linear".') + + if self.clip: + spike_time = bm.where(data < self.threshold, np.inf, spike_time) + spike_steps = bm.round(spike_time / bm.get_dt()).astype(int) + return bm.one_hot(spike_steps, num_classes=int(n_time / bm.get_dt()), axis=0, dtype=x.dtype) diff --git a/brainpy/_src/encoding/stateless_encoding.py b/brainpy/_src/encoding/stateless_encoding.py index 700a6330c..5410d736c 100644 --- a/brainpy/_src/encoding/stateless_encoding.py +++ b/brainpy/_src/encoding/stateless_encoding.py @@ -1,68 +1,189 @@ # -*- coding: utf-8 -*- -from typing import Union, Optional +from typing import Optional -import jax import brainpy.math as bm from brainpy import check -from brainpy.types import ArrayType from .base import Encoder __all__ = [ 'PoissonEncoder', + 'DiffEncoder', ] class PoissonEncoder(Encoder): r"""Encode the rate input as the Poisson spike train. - Given the input :math:`x`, the poisson encoder will output - spikes whose firing probability is :math:`x_{\text{normalize}}`, where - :math:`x_{\text{normalize}}` is normalized into ``[0, 1]`` according + Expected inputs should be between 0 and 1. If not, the input :math:`x` will be + normalized to :math:`x_{\text{normalize}}` within ``[0, 1]`` according to :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`. - Parameters - ---------- - min_val: float - The minimal value in the given data `x`, used to the data normalization. - max_val: float - The maximum value in the given data `x`, used to the data normalization. - seed: int, ArrayType - The seed or key for random generation. + Given the input :math:`x`, the poisson encoder will output + spikes whose firing probability is :math:`x_{\text{normalize}}`. + + + Examples:: + + import brainpy as bp + import brainpy.math as bm + + img = bm.random.random((10, 2)) # image to encode (normalized to [0., 1.]) + encoder = bp.encoding.PoissonEncoder() # the encoder + + # encode the image at each time + for run_index in range(100): + spike = encoder.single_step(img) + # do something + + # or, encode the image at multiple times once + spikes = encoder.multi_steps(img, n_time=10.) + + + Args: + min_val: float. The minimal value in the given data `x`, used to the data normalization. + max_val: float. The maximum value in the given data `x`, used to the data normalization. + gain: float. Scale input features by the gain, defaults to ``1``. + offset: float. Shift input features by the offset, defaults to ``0``. + first_spk_time: float. The time to first spike, defaults to ``0``. """ - def __init__(self, - min_val: Optional[float] = None, - max_val: Optional[float] = None): + def __init__( + self, + min_val: Optional[float] = None, + max_val: Optional[float] = None, + gain: float = 1.0, + offset: float = 0.0, + first_spk_time: float = 0., + ): super().__init__() self.min_val = check.is_float(min_val, 'min_val', allow_none=True) self.max_val = check.is_float(max_val, 'max_val', allow_none=True) + self.gain = check.is_float(gain, allow_none=False) + self.offset = check.is_float(offset, allow_none=False) + self.first_spk_time = check.is_float(first_spk_time) + self.first_spk_step = int(self.first_spk_time / bm.get_dt()) + + def single_step(self, x, i_step: int = None): + """Generate spikes at the single step according to the inputs. + + Args: + x: Array. The rate input. + i_step: int. The time step to generate spikes. - def __call__(self, x: ArrayType, num_step: int = None): + Returns: + out: Array. The encoded spike train. """ + if i_step is None: + return self.multi_steps(x, n_time=None) + else: + return bm.cond(bm.as_jax(i_step < self.first_spk_step), self._zero_out, self.multi_steps, x) + + def multi_steps(self, x, n_time: Optional[float]): + """Generate spikes at multiple steps according to the inputs. + + Args: + x: Array. The rate input. + n_time: float. Encode rate values as spike trains in the given time length. + ``n_time`` is converted into the ``n_step`` according to `n_step = int(n_time / brainpy.math.dt)`. + - If ``n_time=None``, encode the rate values at the current time step. + Users should repeatedly call it to encode `x` as a spike train. + - Else, given the ``x`` with shape ``(S, ...)``, the encoded + spike train is the array with shape ``(n_step, S, ...)``. - Parameters - ---------- - x: ArrayType - The rate input. - num_step: int - Encode rate values as spike trains in the given time length. - - - If ``time_len=None``, encode the rate values at the current time step. - Users should repeatedly call it to encode `x` as a spike train. - - Else, given the ``x`` with shape ``(S, ...)``, the encoded - spike train is the array with shape ``(time_len, S, ...)``. - - Returns - ------- - out: ArrayType - The encoded spike train. + Returns: + out: Array. The encoded spike train. """ - with jax.ensure_compile_time_eval(): - check.is_integer(num_step, 'time_len', min_bound=1, allow_none=True) - if not (self.min_val is None or self.max_val is None): + n_time = int(n_time / bm.get_dt()) + + if (self.min_val is not None) and (self.max_val is not None): x = (x - self.min_val) / (self.max_val - self.min_val) - shape = x.shape if (num_step is None) else ((num_step,) + x.shape) - d = bm.as_jax(bm.random.rand(*shape)) < x - return d.astype(x.dtype) + x = x * self.gain + self.offset + if n_time is not None and self.first_spk_step > 0: + pre = bm.zeros((self.first_spk_step,) + x.shape, dtype=x.dtype) + shape = ((n_time - self.first_spk_step,) + x.shape) + post = bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + return bm.cat([pre, post], axis=0) + else: + shape = x.shape if (n_time is None) else ((n_time - self.first_spk_step,) + x.shape) + return bm.asarray(bm.random.rand(*shape) < x, dtype=x.dtype) + + def _zero_out(self, x): + return bm.zeros_like(x) + + +class DiffEncoder(Encoder): + """Generate spike only when the difference between two subsequent + time steps meets a threshold. + + Optionally include `off_spikes` for negative changes. + + Example:: + + >>> a = bm.array([1, 2, 2.9, 3, 3.9]) + >>> encoder = DiffEncoder(threshold=1) + >>> encoder.multi_steps(a) + Array([1., 0., 0., 0.]) + + >>> encoder = DiffEncoder(threshold=1, padding=True) + >>> encoder.multi_steps(a) + Array([0., 1., 0., 0., 0.]) + + >>> b = bm.array([1, 2, 0, 2, 2.9]) + >>> encoder = DiffEncoder(threshold=1, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 1., 1., -1., 1., 0.]) + + >>> encoder = DiffEncoder(threshold=1, padding=True, off_spike=True) + >>> encoder.multi_steps(b) + Array([ 0., 1., -1., 1., 0.]) + + Args: + threshold: float. Input features with a change greater than the thresold + across one timestep will generate a spike, defaults to ``0.1``. + padding: bool. Used to change how the first time step of spikes are + measured. If ``True``, the first time step will be repeated with itself + resulting in ``0``'s for the output spikes. + If ``False``, the first time step will be padded with ``0``'s, defaults + to ``False``. + off_spike: bool. If ``True``, negative spikes for changes less than + ``-threshold``, defaults to ``False``. + """ + + def __init__( + self, + threshold: float = 0.1, + padding: bool = False, + off_spike: bool = False, + ): + super().__init__() + + self.threshold = threshold + self.padding = padding + self.off_spike = off_spike + + def single_step(self, *args, **kwargs): + raise NotImplementedError(f'{DiffEncoder.__class__.__name__} does not support single-step encoding.') + + def multi_steps(self, x): + """Encoding multistep inputs with the spiking trains. + + Args: + x: Array. The array with the shape of `(num_step, ....)`. + + Returns: + out: Array. The spike train. + """ + if self.padding: + diff = bm.diff(x, axis=0, prepend=x[:1]) + else: + diff = bm.diff(x, axis=0, prepend=bm.zeros((1,) + x.shape[1:], dtype=x.dtype)) + + if self.off_spike: + on_spk = bm.asarray(diff >= self.threshold, dtype=x.dtype) + off_spk = -bm.asarray(diff <= -self.threshold, dtype=x.dtype) + return on_spk + off_spk + + else: + return bm.asarray(diff >= self.threshold, dtype=x.dtype) diff --git a/brainpy/_src/encoding/tests/test_stateless_encoding.py b/brainpy/_src/encoding/tests/test_stateless_encoding.py new file mode 100644 index 000000000..3fec2a964 --- /dev/null +++ b/brainpy/_src/encoding/tests/test_stateless_encoding.py @@ -0,0 +1,79 @@ +import unittest +import brainpy.math as bm +import brainpy as bp + + +class TestDiffEncoder(unittest.TestCase): + def test_delta(self): + a = bm.array([1, 2, 2.9, 3, 3.9]) + encoder = bp.encoding.DiffEncoder(threshold=1) + r = encoder.multi_steps(a) + excepted = bm.asarray([1., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True) + r = encoder.multi_steps(a) + excepted = bm.asarray([0., 1., 0., 0., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + bm.clear_buffer_memory() + + def test_delta_off_spike(self): + b = bm.array([1, 2, 0, 2, 2.9]) + encoder = bp.encoding.DiffEncoder(threshold=1, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([1., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.DiffEncoder(threshold=1, padding=True, off_spike=True) + r = encoder.multi_steps(b) + excepted = bm.asarray([0., 1., -1., 1., 0.]) + self.assertTrue(bm.allclose(r, excepted)) + + bm.clear_buffer_memory() + + +class TestLatencyEncoder(unittest.TestCase): + def test_latency(self): + a = bm.array([0.02, 0.5, 1]) + encoder = bp.encoding.LatencyEncoder(method='linear') + + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + + r = encoder.multi_steps(a, n_time=1.0) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [0., 0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + + encoder = bp.encoding.LatencyEncoder(method='linear', normalize=True) + r = encoder.multi_steps(a, n_time=0.5) + excepted = bm.asarray( + [[0., 0., 1.], + [0., 0., 0.], + [0., 1., 0.], + [0., 0., 0.], + [1., 0., 0.], + ] + ) + self.assertTrue(bm.allclose(r, excepted)) + diff --git a/brainpy/encoding.py b/brainpy/encoding.py index 4a2de0be7..b51f9d744 100644 --- a/brainpy/encoding.py +++ b/brainpy/encoding.py @@ -9,6 +9,7 @@ WeightedPhaseEncoder as WeightedPhaseEncoder, ) from brainpy._src.encoding.stateless_encoding import ( - PoissonEncoder as PoissonEncoder + PoissonEncoder as PoissonEncoder, + DiffEncoder as DiffEncoder, )