[encoding] upgrade encoding methods
chaoming0625 committed Aug 29, 2023
1 parent 9655cb3 commit 68bdb2f
Showing 5 changed files with 350 additions and 108 deletions.
9 changes: 7 additions & 2 deletions brainpy/_src/encoding/base.py
@@ -10,8 +10,13 @@

 class Encoder(BrainPyObject):
   """Base class for encoding rate values as spike trains."""

-  def __call__(self, *args, **kwargs):
-    raise NotImplementedError

   def __repr__(self):
     return self.__class__.__name__

+  def single_step(self, *args, **kwargs):
+    raise NotImplementedError('Please implement the function for single step encoding.')
+
+  def multi_steps(self, *args, **kwargs):
+    raise NotImplementedError('Please implement the function for multiple-step encoding.')
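
The new base interface splits encoding into a per-step path (`single_step`) and a whole-train path (`multi_steps`). A minimal sketch of how a subclass might implement it (hypothetical, not part of this commit; `BernoulliRateEncoder` and its bodies are illustrative only):

import brainpy.math as bm
from brainpy._src.encoding.base import Encoder

class BernoulliRateEncoder(Encoder):
  """Hypothetical encoder: each unit spikes with probability given by its rate."""

  def single_step(self, x, i_step: int = None):
    # One time step: a Bernoulli draw per unit, with x assumed in [0, 1].
    return bm.random.rand(*x.shape) < x

  def multi_steps(self, x, n_time: float):
    # Whole spike train: an independent draw for every step in the window.
    n_steps = int(n_time / bm.get_dt())
    return bm.random.rand(n_steps, *x.shape) < x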

166 changes: 101 additions & 65 deletions brainpy/_src/encoding/stateful_encoding.py
@@ -1,9 +1,10 @@
 # -*- coding: utf-8 -*-

 import math
-from typing import Union, Callable
+from typing import Union, Callable, Optional

 import jax
+import numpy as np

 import brainpy.math as bm
 from brainpy import check
@@ -47,13 +48,10 @@ def __init__(self,
                weight_fun: Callable = None):
     super().__init__()

-    check.is_integer(num_phase, 'num_phase', min_bound=1)
-    check.is_float(min_val, 'min_val')
-    check.is_float(max_val, 'max_val')
     check.is_callable(weight_fun, 'weight_fun', allow_none=True)
-    self.num_phase = num_phase
-    self.min_val = min_val
-    self.max_val = max_val
+    self.num_phase = check.is_integer(num_phase, 'num_phase', min_bound=1)
+    self.min_val = check.is_float(min_val, 'min_val')
+    self.max_val = check.is_float(max_val, 'max_val')
     self.weight_fun = (lambda i: 2 ** (-(i % num_phase + 1))) if weight_fun is None else weight_fun
     self.scale = (1 - self.weight_fun(self.num_phase - 1)) / (self.max_val - self.min_val)
@@ -88,74 +86,112 @@ def f(i):


 class LatencyEncoder(Encoder):
-  r"""Encode the rate input as the spike train.
+  r"""Encode the rate input as the spike train using the latency encoding.

-  The latency encoder will encode ``x`` (normalized into ``[0, 1]`` according to
+  Use input features to determine time-to-first spike.
+
+  Expected inputs should be between 0 and 1. If not, the latency encoder will encode ``x``
+  (normalized into ``[0, 1]`` according to
   :math:`x_{\text{normalize}} = \frac{x-\text{min_val}}{\text{max_val} - \text{min_val}}`)
   to spikes whose firing time is :math:`0 \le t_f \le \text{num_period}-1`.
   A larger ``x`` will cause the earlier firing time.

-  Parameters
-  ----------
-  min_val: float
-    The minimal value in the given data `x`, used to the data normalization.
-  max_val: float
-    The maximum value in the given data `x`, used to the data normalization.
-  num_period: int
-    The periodic firing time step.
-  method: str
-    How to convert intensity to firing time. Currently, we support `linear` or `log`.
-    - If ``method='linear'``, the firing rate is calculated as
-      :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`.
-    - If ``method='log'``, the firing rate is calculated as
-      :math:`t_f(x) = (\text{num_period} - 1) - ln(\alpha * x + 1)`,
-      where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`.
+  Example::
+
+     >>> a = bm.array([0.02, 0.5, 1])
+     >>> encoder = LatencyEncoder(method='linear', normalize=True)
+     >>> encoder.multi_steps(a, n_time=5)
+     Array([[0., 0., 1.],
+            [0., 0., 0.],
+            [0., 1., 0.],
+            [0., 0., 0.],
+            [1., 0., 0.]])
+
+  Args:
+    min_val: float. The minimal value in the given data `x`, used for data normalization.
+    max_val: float. The maximum value in the given data `x`, used for data normalization.
+    method: str. How to convert intensity to firing time. Currently, we support `linear` or `log`.
+      - If ``method='linear'``, the firing time is calculated as
+        :math:`t_f(x) = (\text{num_period} - 1)(1 - x)`.
+      - If ``method='log'``, the firing time is calculated as
+        :math:`t_f(x) = (\text{num_period} - 1) - \ln(\alpha x + 1)`,
+        where :math:`\alpha` satisfies :math:`t_f(1) = \text{num_period} - 1`.
+    threshold: float. Input features below the threshold will fire at the
+      final time step unless ``clip=True``, in which case they will not
+      fire at all; defaults to ``0.01``.
+    clip: bool. Option to remove spikes from features that fall
+      below the threshold; defaults to ``False``.
+    tau: float. RC time constant for the LIF model used to calculate
+      the firing time; defaults to ``1``.
+    normalize: bool. Option to normalize the latency code such that
+      the final spike(s) occur within ``n_time``; defaults to ``False``.
+    epsilon: float. A tiny positive value to avoid rounding errors when
+      computing the spike steps; defaults to ``1e-7``.
+  """

-  def __init__(self,
-               min_val: float,
-               max_val: float,
-               num_period: int,
-               method: str = 'linear'):
+  def __init__(
+      self,
+      min_val: float = None,
+      max_val: float = None,
+      method: str = 'log',
+      threshold: float = 0.01,
+      clip: bool = False,
+      tau: float = 1.,
+      normalize: bool = False,
+      first_spk_time: float = 0.,
+      epsilon: float = 1e-7,
+  ):
     super().__init__()

-    check.is_integer(num_period, 'num_period', min_bound=1)
-    check.is_float(min_val, 'min_val')
-    check.is_float(max_val, 'max_val')
-    assert method in ['linear', 'log']
-    self.num_period = num_period
-    self.min_val = min_val
-    self.max_val = max_val
+    if method not in ['linear', 'log']:
+      raise ValueError('The conversion method can only be "linear" and "log".')
     self.method = method

-  def __call__(self, x: ArrayType, i_step: Union[int, ArrayType]):
-    """Encoding function.
-
-    Parameters
-    ----------
-    x: ArrayType
-      The input rate value.
-    i_step: int, ArrayType
-      The indices of the time step.
-
-    Returns
-    -------
-    out: ArrayType
-      The encoded spike train.
+    self.min_val = check.is_float(min_val, 'min_val', allow_none=True)
+    self.max_val = check.is_float(max_val, 'max_val', allow_none=True)
+    if threshold < 0 or threshold > 1:
+      raise ValueError(f"``threshold`` [{threshold}] must be between [0, 1]")
+    self.threshold = threshold
+    self.clip = clip
+    self.tau = tau
+    self.normalize = normalize
+    self.first_spk_time = check.is_float(first_spk_time)
+    self.first_spk_step = int(first_spk_time / bm.get_dt())
+    self.epsilon = epsilon
+
+  def single_step(self, x, i_step: int = None):
+    raise NotImplementedError
+
+  def multi_steps(self, data, n_time: Optional[float] = None):
+    """Generate latency spikes according to the given input data.
+
+    Ensures ``x`` is in [0., 1.].
+
+    Args:
+      data: The rate-based input.
+      n_time: float. The total time to generate data. If None, use ``tau`` instead.
+
+    Returns:
+      out: array. The output spiking trains.
     """
-    _temp = self.num_period - 1.
-    if self.method == 'log':
-      alpha = math.exp(_temp) - 1.
-      t_f = bm.round(_temp - bm.log(alpha * x + 1.)).astype(bm.int_)
-    else:
-      t_f = bm.round(_temp * (1. - x)).astype(bm.int_)
+    if n_time is None:
+      n_time = self.tau
+    tau = n_time if self.normalize else self.tau
+    x = data
+    if self.min_val is not None and self.max_val is not None:
+      x = (x - self.min_val) / (self.max_val - self.min_val)
+    if self.method == 'linear':
+      spike_time = (tau - self.first_spk_time - bm.dt) * (1 - x) + self.first_spk_time

-    def f(i):
-      return bm.as_jax(t_f == (i % self.num_period), dtype=x.dtype)
-
-    if isinstance(i_step, int):
-      return f(i_step)
-    else:
-      assert isinstance(i_step, (jax.Array, bm.Array))
-      return jax.vmap(f, i_step)
+    elif self.method == 'log':
+      x = bm.maximum(x, self.threshold + self.epsilon)  # saturate values below threshold
+      spike_time = (tau - self.first_spk_time - bm.dt) * bm.log(x / (x - self.threshold)) + self.first_spk_time
+
+    else:
+      raise ValueError(f'Unsupported method: {self.method}. Only support "log" and "linear".')
+
+    if self.clip:
+      spike_time = bm.where(data < self.threshold, np.inf, spike_time)
+    spike_steps = bm.round(spike_time / bm.get_dt()).astype(int)
+    return bm.one_hot(spike_steps, num_classes=int(n_time / bm.get_dt()), axis=0, dtype=x.dtype)
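
A usage sketch for the new `multi_steps` API (it mirrors the docstring example above; assumes the public alias `brainpy.encoding.LatencyEncoder` and `dt = 1.`):

import brainpy as bp
import brainpy.math as bm

bm.set_dt(1.)
data = bm.array([0.02, 0.5, 1.])
encoder = bp.encoding.LatencyEncoder(method='linear', normalize=True)

# With normalize=True and n_time=5, spike_time = 4 * (1 - x):
# x=1. fires at step 0, x=0.5 at step 2, x=0.02 at step 4 (rounded from 3.92).
spikes = encoder.multi_steps(data, n_time=5)
print(spikes)  # shape (5, 3): a one-hot code along the time axis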
