[compatibility] more PyTorch- and TensorFlow-compatible operators
chaoming0625 committed Jan 29, 2023
1 parent 33db354 commit a52de70
Showing 7 changed files with 148 additions and 22 deletions.
21 changes: 20 additions & 1 deletion brainpy/_src/math/_utils.py
@@ -39,10 +39,29 @@ def _compatible_with_brainpy_array(fun: Callable):
  @functools.wraps(fun)
  def new_fun(*args, **kwargs):
    args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf)
    out = None
    if len(kwargs):
      # compatible with PyTorch syntax
      if 'dim' in kwargs:
        kwargs['axis'] = kwargs.pop('dim')
      # compatible with PyTorch syntax
      if 'keepdim' in kwargs:
        kwargs['keepdims'] = kwargs.pop('keepdim')
      # compatible with TensorFlow (v1) syntax
      if 'keep_dims' in kwargs:
        kwargs['keepdims'] = kwargs.pop('keep_dims')
      # compatible with NumPy/PyTorch syntax: write the result into ``out``
      if 'out' in kwargs:
        out = kwargs.pop('out')  # pop it so it is not forwarded to the wrapped jnp function
        if not isinstance(out, Array):
          raise TypeError(f'"out" must be an instance of brainpy Array, but we got {type(out)}')
      # convert any brainpy Array in the keyword arguments to a jax.Array
      kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf)
    r = fun(*args, **kwargs)
    if out is None:
      return tree_map(_return, r)
    else:
      out.value = r
      return out

  new_fun.__doc__ = getattr(fun, "__doc__", None)

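Usage sketch (not part of the commit): how a wrapped reduction behaves after this change, assuming the wrapper is applied to jnp.sum and re-exported as brainpy.math.sum, as elsewhere in this package.

import brainpy.math as bm

x = bm.arange(6).reshape((2, 3))

# NumPy spelling
bm.sum(x, axis=1, keepdims=True)          # shape (2, 1)
# PyTorch spelling: `dim`/`keepdim` are rewritten to `axis`/`keepdims`
bm.sum(x, dim=1, keepdim=True)            # shape (2, 1)
# NumPy/PyTorch `out=`: the target must be a brainpy Array
out = bm.zeros((2, 1))
bm.sum(x, dim=1, keepdim=True, out=out)   # writes into out.value and returns out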
49 changes: 45 additions & 4 deletions brainpy/_src/math/compat_numpy.py
@@ -1,14 +1,24 @@
# -*- coding: utf-8 -*-

from typing import (Union, Any, Protocol)

import jax.numpy as jnp
import numpy as np
from jax.tree_util import tree_flatten, tree_unflatten
from jax.tree_util import tree_map

from ._utils import _compatible_with_brainpy_array, _as_jax_array_
from .arrayinterporate import *
from .ndarray import Array


class SupportsDType(Protocol):
  @property
  def dtype(self) -> np.dtype: ...


DTypeLike = Union[Any, str, np.dtype, SupportsDType]

__all__ = [
  'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
  'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
@@ -99,10 +109,39 @@

]


_min = min  # keep a reference to the builtin, which this module's ``min`` shadows
_max = max  # keep a reference to the builtin, which this module's ``max`` shadows

# def concatenate(arrays: Union[np.ndarray, Array, Sequence[Array]],
#                 axis: Optional[int] = None,
#                 dim: Optional[int] = None,
#                 dtype: Optional[DTypeLike] = None) -> Array:
# """Join a sequence of arrays along an existing axis.
#
#
# Parameters
# ----------
# a1, a2, ... : sequence of array_like
# The arrays must have the same shape, except in the dimension
# corresponding to `axis` (the first, by default).
# axis : int, optional
# The axis along which the arrays will be joined. If axis is None,
# arrays are flattened before use. Default is 0.
# dtype : str or dtype
# If provided, the destination array will have this dtype. Cannot be
# provided together with `out`.
#
# Returns
# -------
# res : ndarray
# The concatenated array.
# """
#   axis = one_of(0, axis, dim, names=['axis', 'dim'])
#   r = jnp.concatenate(tree_map(_as_jax_array_, arrays, is_leaf=_is_leaf),
#                       axis=axis,
#                       dtype=dtype)
#   return _return(r)


def fill_diagonal(a, val, inplace=True):
  if a.ndim < 2:
@@ -112,13 +151,14 @@ def fill_diagonal(a, val, inplace=True):
                     'it requires a brainpy Array. If you want to disable '
                     'inplace updating, use ``fill_diagonal(inplace=False)``.')
  val = val.value if isinstance(val, Array) else val
  i, j = jnp.diag_indices(_min(a.shape[-2:]))  # _min: the saved builtin, see above
  r = as_jax(a).at[..., i, j].set(val)
  if inplace:
    a.value = r
  else:
    return r


def zeros(shape, dtype=None):
  return Array(jnp.zeros(shape, dtype=dtype))

@@ -191,6 +231,7 @@ def logspace(*args, **kwargs):
  kwargs = {k: _as_jax_array_(v) for k, v in kwargs.items()}
  return Array(jnp.logspace(*args, **kwargs))


def asanyarray(a, dtype=None, order=None):
  return asarray(a, dtype=dtype, order=order)

@@ -612,7 +653,7 @@ def common_type(*arrays):
      p = array_precision.get(t, None)
      if p is None:
        raise TypeError("can't get common type for non-numeric array")
    precision = _max(precision, p)  # _max: the saved builtin, see above
  if is_complex:
    return array_type[1][precision]
  else:
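Usage sketch (not part of the commit) for the in-place/out-of-place behavior of fill_diagonal, assuming it is re-exported under brainpy.math like the other functions in this module.

import brainpy.math as bm

a = bm.ones((3, 3))
bm.fill_diagonal(a, 0.)                       # in-place: writes into a.value
b = bm.fill_diagonal(a, 5., inplace=False)    # returns a jax.Array; a is unchanged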
32 changes: 32 additions & 0 deletions brainpy/_src/math/compat_pytorch.py
@@ -5,12 +5,25 @@
import numpy as np

from .ndarray import Array, _as_jax_array_
from .compat_numpy import (
  concatenate,
)

__all__ = [
  'Tensor',
  'flatten',
  'cat',

  # data types
  'bfloat16', 'half', 'float', 'double', 'cfloat', 'cdouble', 'short', 'int', 'long', 'bool'
]



Tensor = Array
cat = concatenate


def flatten(input: Union[jax.Array, Array],
            start_dim: Optional[int] = None,
            end_dim: Optional[int] = None) -> jax.Array:
@@ -56,3 +69,22 @@ def flatten(input: Union[jax.Array, Array],
  new_shape = shape[:start_dim] + (np.prod(shape[start_dim: end_dim], dtype=int), ) + shape[end_dim:]
  return jnp.reshape(input, new_shape)

# data types
bfloat16 = jnp.bfloat16
half = jnp.float16
float = jnp.float32
double = jnp.float64
cfloat = jnp.complex64
cdouble = jnp.complex128
short = jnp.int16
int = jnp.int32
long = jnp.int64
bool = jnp.bool_
# missing types #
# chalf = np.complex32
# quint8 = jnp.quint8
# qint8 = jnp.qint8
# qint32 = jnp.qint32
# quint4x2 = jnp.quint4x2
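
A PyTorch-flavoured usage sketch (not part of the commit), assuming these names are re-exported under brainpy.math as brainpy/math/compat_pytorch.py below suggests.

import brainpy.math as bm

x = bm.ones((2, 3), dtype=bm.float)   # bm.float is an alias of jnp.float32
y = bm.cat([x, x], dim=0)             # torch.cat spelling; `dim` resolves to `axis`
z = bm.flatten(y)                     # shape (4, 3) -> (12,)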


36 changes: 20 additions & 16 deletions brainpy/_src/math/compat_tensorflow.py
@@ -2,18 +2,31 @@
import jax.ops

from .ndarray import _return, _as_jax_array_
from .compat_numpy import (
  prod, min, sum, all, any, mean, std, var, concatenate, clip
)

__all__ = [
  'concat',
  'reduce_sum', 'reduce_max', 'reduce_min', 'reduce_mean', 'reduce_all', 'reduce_any',
  'reduce_logsumexp', 'reduce_prod', 'reduce_std', 'reduce_variance', 'reduce_euclidean_norm',
  'unsorted_segment_sqrt_n', 'segment_mean', 'unsorted_segment_sum', 'unsorted_segment_prod',
  'unsorted_segment_max', 'unsorted_segment_min', 'unsorted_segment_mean',
  'clip_by_value',
]


reduce_prod = prod
reduce_sum = sum
reduce_all = all
reduce_any = any
reduce_min = min
reduce_mean = mean
reduce_std = std
reduce_variance = var
concat = concatenate
clip_by_value = clip

def reduce_logsumexp(input_tensor, axis=None, keep_dims=False):
"""Computes log(sum(exp(elements across dimensions of a tensor))).
@@ -95,15 +108,6 @@ def reduce_max(input_tensor, axis=None, keep_dims=False):
  return _return(jnp.max(_as_jax_array_(input_tensor), axis=axis, keepdims=keep_dims))


def segment_mean(data, segment_ids):
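A TensorFlow-flavoured usage sketch (not part of the commit); the import path follows brainpy/math/compat_tensorflow.py as shown in this diff.

import brainpy.math as bm
from brainpy.math.compat_tensorflow import concat, reduce_sum, clip_by_value

x = bm.random.rand(4, 5)
s = reduce_sum(x, axis=1, keep_dims=True)   # TF-1 keyword; rewritten to keepdims
c = clip_by_value(x, 0.2, 0.8)              # alias of clip(x, 0.2, 0.8)
m = concat([x, x], axis=0)                  # alias of concatenate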
17 changes: 16 additions & 1 deletion brainpy/_src/tools/others.py
@@ -3,14 +3,15 @@
import collections.abc
import _thread as thread
import threading
from typing import Optional, Tuple, Callable, Union, Sequence, TypeVar, Any

import numpy as np
from jax import lax
from jax.experimental import host_callback
from tqdm.auto import tqdm

__all__ = [
  'one_of',
  'replicate',
  'not_customized',
  'to_size',
@@ -20,6 +21,20 @@
]


def one_of(default: Any, *choices, names: Sequence[str] = None):
  names = [f'arg{i}' for i in range(len(choices))] if names is None else names
  res = default
  has_chosen = False
  for c in choices:
    if c is not None:
      if has_chosen:
        raise ValueError(f'Only one of {names} may be given, '
                         f'but got {list(zip(names, choices))}')
      has_chosen = True
      res = c
  return res


T = TypeVar('T')


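A quick sketch of one_of (not part of the commit), which the commented-out concatenate above would use to resolve the NumPy `axis` and PyTorch `dim` spellings; the import path follows the file location in this diff.

from brainpy._src.tools.others import one_of

one_of(0, None, None, names=['axis', 'dim'])   # both unset -> default 0
one_of(0, 1, None, names=['axis', 'dim'])      # -> 1
one_of(0, 1, 2, names=['axis', 'dim'])         # raises ValueError: only one may be given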
13 changes: 13 additions & 0 deletions brainpy/math/compat_pytorch.py
@@ -1,4 +1,17 @@

from brainpy._src.math.compat_pytorch import (
  Tensor as Tensor,
  flatten as flatten,
  cat as cat,

  bfloat16 as bfloat16,
  half as half,
  float as float,
  double as double,
  cfloat as cfloat,
  cdouble as cdouble,
  short as short,
  int as int,
  long as long,
  bool as bool,
)
2 changes: 2 additions & 0 deletions brainpy/math/compat_tensorflow.py
@@ -1,5 +1,6 @@

from brainpy._src.math.compat_tensorflow import (
  concat as concat,
  reduce_sum as reduce_sum,
  reduce_max as reduce_max,
  reduce_min as reduce_min,
@@ -18,5 +19,6 @@
  unsorted_segment_max as unsorted_segment_max,
  unsorted_segment_min as unsorted_segment_min,
  unsorted_segment_mean as unsorted_segment_mean,
  clip_by_value as clip_by_value,
)
