Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Nov 23, 2024
1 parent ec39076 commit 34fe1e2
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 26 deletions.
27 changes: 27 additions & 0 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
__all__ = [
'import_taichi',
'raise_taichi_not_found',
'import_braintaichi',
'raise_braintaichi_not_found',
'import_numba',
'raise_numba_not_found',
'import_cupy',
Expand All @@ -20,6 +22,7 @@

numba = None
taichi = None
braintaichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
Expand All @@ -33,6 +36,10 @@
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n'
'> pip install braintaichi -U')


os.environ["TI_LOG_LEVEL"] = "error"


Expand Down Expand Up @@ -69,6 +76,26 @@ def import_taichi(error_if_not_found=True):
def raise_taichi_not_found(*args, **kwargs):
raise ModuleNotFoundError(taichi_install_info)

def import_braintaichi(error_if_not_found=True):
"""Internal API to import braintaichi.
If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True,
otherwise it will return None.
"""
global braintaichi
if braintaichi is None:
try:
import braintaichi as braintaichi
except ModuleNotFoundError:
if error_if_not_found:
raise_braintaichi_not_found()
else:
return None
return braintaichi

def raise_braintaichi_not_found():
raise ModuleNotFoundError(braintaichi_install_info)


def import_numba(error_if_not_found=True):
"""
Expand Down
6 changes: 3 additions & 3 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import jax
import jax.numpy as jnp
import numpy as np
import braintaichi as bti

from brainpy import math as bm
from brainpy._src import connect, initialize as init
from brainpy._src.context import share
from brainpy._src.dependency_check import import_taichi
from brainpy._src.dependency_check import import_taichi, import_braintaichi
from brainpy._src.dnn.base import Layer
from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP
from brainpy.check import is_initializer
Expand All @@ -21,6 +20,7 @@
from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter
from brainpy.types import ArrayType, Sharding

bti = import_braintaichi(error_if_not_found=False)
ti = import_taichi(error_if_not_found=False)

__all__ = [
Expand Down Expand Up @@ -239,7 +239,7 @@ def update(self, x):
return x


if ti is not None:
if ti is not None and bti is not None:

# @numba.njit(nogil=True, fastmath=True, parallel=False)
# def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w):
Expand Down
10 changes: 8 additions & 2 deletions brainpy/_src/math/event/csr_matmat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@

from typing import Union, Tuple

from braintaichi import event_csrmm as bt_event_csrmm

from jax import numpy as jnp

from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'csrmm',
Expand Down Expand Up @@ -38,4 +41,7 @@ def csrmm(
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product product.
"""
return bt_event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
if bti is None:
raise_braintaichi_not_found()

return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
10 changes: 8 additions & 2 deletions brainpy/_src/math/event/csr_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
from typing import Union, Tuple

import jax
from braintaichi import event_csrmv as bt_event_csrmv

from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)


__all__ = [
'csrmv'
Expand Down Expand Up @@ -60,5 +64,7 @@ def csrmv(
The array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
if bti is None:
raise_braintaichi_not_found()

return bt_event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose)
return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose)
15 changes: 11 additions & 4 deletions brainpy/_src/math/jitconn/event_matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from typing import Tuple, Optional

import jax
from braintaichi import jitc_event_mv_prob_homo, jitc_event_mv_prob_uniform, jitc_event_mv_prob_normal

from brainpy._src.math.jitconn.matvec import (mv_prob_homo,
mv_prob_uniform,
mv_prob_normal)
from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'event_mv_prob_homo',
Expand All @@ -27,7 +28,9 @@ def event_mv_prob_homo(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return jitc_event_mv_prob_homo(events, weight, conn_prob, seed,
if bti is None:
raise_braintaichi_not_found()
return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed,
shape=shape,
transpose=transpose,
outdim_parallel=outdim_parallel)
Expand All @@ -47,7 +50,9 @@ def event_mv_prob_uniform(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()
return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand All @@ -65,7 +70,9 @@ def event_mv_prob_normal(
transpose: bool = False,
outdim_parallel: bool = True,
) -> jax.Array:
return jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()
return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand Down
32 changes: 23 additions & 9 deletions brainpy/_src/math/jitconn/matvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,16 @@

import jax
import numpy as np
from braintaichi import jitc_mv_prob_homo, jitc_mv_prob_uniform, jitc_mv_prob_normal, \
get_homo_weight_matrix as bt_get_homo_weight_matrix, get_uniform_weight_matrix as bt_get_uniform_weight_matrix, \
get_normal_weight_matrix as bt_get_normal_weight_matrix
from jax import numpy as jnp

from brainpy._src.math import defaults
from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from brainpy._src.math.op_register import XLACustomOp
from brainpy.errors import PackageMissingError
from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'mv_prob_homo',
Expand Down Expand Up @@ -83,7 +83,10 @@ def mv_prob_homo(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()

return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand Down Expand Up @@ -148,7 +151,10 @@ def mv_prob_uniform(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()

return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand Down Expand Up @@ -213,7 +219,9 @@ def mv_prob_normal(
out: Array, ndarray
The output of :math:`y = M @ v`.
"""
return jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()
return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand Down Expand Up @@ -248,7 +256,9 @@ def get_homo_weight_matrix(
out: Array, ndarray
The connection matrix :math:`M`.
"""
return bt_get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)
if bti is None:
raise_braintaichi_not_found()
return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)


def get_uniform_weight_matrix(
Expand Down Expand Up @@ -287,7 +297,9 @@ def get_uniform_weight_matrix(
out: Array, ndarray
The weight matrix :math:`M`.
"""
return bt_get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape,
if bti is None:
raise_braintaichi_not_found()
return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)


Expand Down Expand Up @@ -325,7 +337,9 @@ def get_normal_weight_matrix(
out: Array, ndarray
The weight matrix :math:`M`.
"""
return bt_get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed,
if bti is None:
raise_braintaichi_not_found()
return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed,
shape=shape,
transpose=transpose, outdim_parallel=outdim_parallel)

8 changes: 6 additions & 2 deletions brainpy/_src/math/sparse/coo_mv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@

from typing import Union, Tuple

from braintaichi import coomv as bt_coomv
from jax import numpy as jnp

from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'coomv',
Expand Down Expand Up @@ -59,8 +61,10 @@ def coomv(
An array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
if bti is None:
raise_braintaichi_not_found()

return bt_coomv(
return bti.coomv(
data=data,
row=row,
col=col,
Expand Down
8 changes: 6 additions & 2 deletions brainpy/_src/math/sparse/csr_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from typing import Union, Tuple

from braintaichi import csrmm as bt_csrmm
from jax import numpy as jnp

from brainpy._src.math.ndarray import Array
from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'csrmm',
Expand Down Expand Up @@ -40,4 +41,7 @@ def csrmm(
C : array of shape ``(shape[1] if transpose else shape[0], cols)``
representing the matrix-matrix product.
"""
return bt_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
if bti is None:
raise_braintaichi_not_found()

return bti.csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose)
7 changes: 5 additions & 2 deletions brainpy/_src/math/sparse/csr_mv.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

from typing import Union, Tuple

from braintaichi import csrmv as bt_csrmv
from jax import numpy as jnp

from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found
from brainpy._src.math.ndarray import Array

bti = import_braintaichi(error_if_not_found=False)

__all__ = [
'csrmv',
Expand Down Expand Up @@ -60,6 +61,8 @@ def csrmv(
The array of shape ``(shape[1] if transpose else shape[0],)`` representing
the matrix vector product.
"""
if bti is None:
raise_braintaichi_not_found()

return bt_csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose)
return bti.csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose)

0 comments on commit 34fe1e2

Please sign in to comment.