diff --git a/brainpy/_src/dependency_check.py b/brainpy/_src/dependency_check.py index 89eb9c17..ab3c94c2 100644 --- a/brainpy/_src/dependency_check.py +++ b/brainpy/_src/dependency_check.py @@ -6,6 +6,8 @@ __all__ = [ 'import_taichi', 'raise_taichi_not_found', + 'import_braintaichi', + 'raise_braintaichi_not_found', 'import_numba', 'raise_numba_not_found', 'import_cupy', @@ -20,6 +22,7 @@ numba = None taichi = None +braintaichi = None cupy = None cupy_jit = None brainpylib_cpu_ops = None @@ -33,6 +36,10 @@ cupy_install_info = ('We need cupy. Please install cupy by pip . \n' 'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n' 'For CUDA v12.x > pip install cupy-cuda12x\n') +braintaichi_install_info = ('We need braintaichi. Please install braintaichi by pip . \n' + '> pip install braintaichi -U') + + os.environ["TI_LOG_LEVEL"] = "error" @@ -69,6 +76,26 @@ def import_taichi(error_if_not_found=True): def raise_taichi_not_found(*args, **kwargs): raise ModuleNotFoundError(taichi_install_info) +def import_braintaichi(error_if_not_found=True): + """Internal API to import braintaichi. + + If braintaichi is not found, it will raise a ModuleNotFoundError if error_if_not_found is True, + otherwise it will return None. + """ + global braintaichi + if braintaichi is None: + try: + import braintaichi as braintaichi + except ModuleNotFoundError: + if error_if_not_found: + raise_braintaichi_not_found() + else: + return None + return braintaichi + +def raise_braintaichi_not_found(): + raise ModuleNotFoundError(braintaichi_install_info) + def import_numba(error_if_not_found=True): """ diff --git a/brainpy/_src/dnn/linear.py b/brainpy/_src/dnn/linear.py index c923454c..5be3e89a 100644 --- a/brainpy/_src/dnn/linear.py +++ b/brainpy/_src/dnn/linear.py @@ -7,12 +7,11 @@ import jax import jax.numpy as jnp import numpy as np -import braintaichi as bti from brainpy import math as bm from brainpy._src import connect, initialize as init from brainpy._src.context import share -from brainpy._src.dependency_check import import_taichi +from brainpy._src.dependency_check import import_taichi, import_braintaichi from brainpy._src.dnn.base import Layer from brainpy._src.mixin import SupportOnline, SupportOffline, SupportSTDP from brainpy.check import is_initializer @@ -21,6 +20,7 @@ from brainpy.initialize import XavierNormal, ZeroInit, Initializer, parameter from brainpy.types import ArrayType, Sharding +bti = import_braintaichi(error_if_not_found=False) ti = import_taichi(error_if_not_found=False) __all__ = [ @@ -239,7 +239,7 @@ def update(self, x): return x -if ti is not None: +if ti is not None and bti is not None: # @numba.njit(nogil=True, fastmath=True, parallel=False) # def _cpu_dense_on_post(weight, spike, trace, w_min, w_max, out_w): diff --git a/brainpy/_src/math/event/csr_matmat.py b/brainpy/_src/math/event/csr_matmat.py index aaf8695c..b78afad7 100644 --- a/brainpy/_src/math/event/csr_matmat.py +++ b/brainpy/_src/math/event/csr_matmat.py @@ -3,10 +3,13 @@ from typing import Union, Tuple -from braintaichi import event_csrmm as bt_event_csrmm + from jax import numpy as jnp from brainpy._src.math.ndarray import Array +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmm', @@ -38,4 +41,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product product. """ - return bt_event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) + if bti is None: + raise_braintaichi_not_found() + + return bti.event_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/event/csr_matvec.py b/brainpy/_src/math/event/csr_matvec.py index 7cd527c3..3969ee6b 100644 --- a/brainpy/_src/math/event/csr_matvec.py +++ b/brainpy/_src/math/event/csr_matvec.py @@ -13,7 +13,11 @@ from typing import Union, Tuple import jax -from braintaichi import event_csrmv as bt_event_csrmv + +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) + __all__ = [ 'csrmv' @@ -60,5 +64,7 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ + if bti is None: + raise_braintaichi_not_found() - return bt_event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) + return bti.event_csrmv(data, indices, indptr, events, shape=shape, transpose=transpose) diff --git a/brainpy/_src/math/jitconn/event_matvec.py b/brainpy/_src/math/jitconn/event_matvec.py index 2f22145b..80bba29b 100644 --- a/brainpy/_src/math/jitconn/event_matvec.py +++ b/brainpy/_src/math/jitconn/event_matvec.py @@ -3,12 +3,13 @@ from typing import Tuple, Optional import jax -from braintaichi import jitc_event_mv_prob_homo, jitc_event_mv_prob_uniform, jitc_event_mv_prob_normal from brainpy._src.math.jitconn.matvec import (mv_prob_homo, mv_prob_uniform, mv_prob_normal) +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'event_mv_prob_homo', @@ -27,7 +28,9 @@ def event_mv_prob_homo( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return jitc_event_mv_prob_homo(events, weight, conn_prob, seed, + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_homo(events, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -47,7 +50,9 @@ def event_mv_prob_uniform( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -65,7 +70,9 @@ def event_mv_prob_normal( transpose: bool = False, outdim_parallel: bool = True, ) -> jax.Array: - return jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_event_mv_prob_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) diff --git a/brainpy/_src/math/jitconn/matvec.py b/brainpy/_src/math/jitconn/matvec.py index 5b318e72..4d4dd25a 100644 --- a/brainpy/_src/math/jitconn/matvec.py +++ b/brainpy/_src/math/jitconn/matvec.py @@ -4,9 +4,6 @@ import jax import numpy as np -from braintaichi import jitc_mv_prob_homo, jitc_mv_prob_uniform, jitc_mv_prob_normal, \ - get_homo_weight_matrix as bt_get_homo_weight_matrix, get_uniform_weight_matrix as bt_get_uniform_weight_matrix, \ - get_normal_weight_matrix as bt_get_normal_weight_matrix from jax import numpy as jnp from brainpy._src.math import defaults @@ -14,6 +11,9 @@ from brainpy._src.math.ndarray import Array from brainpy._src.math.op_register import XLACustomOp from brainpy.errors import PackageMissingError +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'mv_prob_homo', @@ -83,7 +83,10 @@ def mv_prob_homo( out: Array, ndarray The output of :math:`y = M @ v`. """ - return jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + + return bti.jitc_mv_prob_homo(vector, weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -148,7 +151,10 @@ def mv_prob_uniform( out: Array, ndarray The output of :math:`y = M @ v`. """ - return jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + + return bti.jitc_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -213,7 +219,9 @@ def mv_prob_normal( out: Array, ndarray The output of :math:`y = M @ v`. """ - return jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + return bti.jitc_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -248,7 +256,9 @@ def get_homo_weight_matrix( out: Array, ndarray The connection matrix :math:`M`. """ - return bt_get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + if bti is None: + raise_braintaichi_not_found() + return bti.get_homo_weight_matrix(weight, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) def get_uniform_weight_matrix( @@ -287,7 +297,9 @@ def get_uniform_weight_matrix( out: Array, ndarray The weight matrix :math:`M`. """ - return bt_get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape, + if bti is None: + raise_braintaichi_not_found() + return bti.get_uniform_weight_matrix(w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) @@ -325,7 +337,9 @@ def get_normal_weight_matrix( out: Array, ndarray The weight matrix :math:`M`. """ - return bt_get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed, + if bti is None: + raise_braintaichi_not_found() + return bti.get_normal_weight_matrix(w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) diff --git a/brainpy/_src/math/sparse/coo_mv.py b/brainpy/_src/math/sparse/coo_mv.py index 439a99ae..c9a46ff6 100644 --- a/brainpy/_src/math/sparse/coo_mv.py +++ b/brainpy/_src/math/sparse/coo_mv.py @@ -3,10 +3,12 @@ from typing import Union, Tuple -from braintaichi import coomv as bt_coomv from jax import numpy as jnp from brainpy._src.math.ndarray import Array +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found + +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'coomv', @@ -59,8 +61,10 @@ def coomv( An array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ + if bti is None: + raise_braintaichi_not_found() - return bt_coomv( + return bti.coomv( data=data, row=row, col=col, diff --git a/brainpy/_src/math/sparse/csr_mm.py b/brainpy/_src/math/sparse/csr_mm.py index ca643838..4d5b0d6c 100644 --- a/brainpy/_src/math/sparse/csr_mm.py +++ b/brainpy/_src/math/sparse/csr_mm.py @@ -3,11 +3,12 @@ from typing import Union, Tuple -from braintaichi import csrmm as bt_csrmm from jax import numpy as jnp from brainpy._src.math.ndarray import Array +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmm', @@ -40,4 +41,7 @@ def csrmm( C : array of shape ``(shape[1] if transpose else shape[0], cols)`` representing the matrix-matrix product. """ - return bt_csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) \ No newline at end of file + if bti is None: + raise_braintaichi_not_found() + + return bti.csrmm(data, indices, indptr, matrix, shape=shape, transpose=transpose) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/csr_mv.py b/brainpy/_src/math/sparse/csr_mv.py index 63a8dcec..c39744bb 100644 --- a/brainpy/_src/math/sparse/csr_mv.py +++ b/brainpy/_src/math/sparse/csr_mv.py @@ -3,11 +3,12 @@ from typing import Union, Tuple -from braintaichi import csrmv as bt_csrmv from jax import numpy as jnp +from brainpy._src.dependency_check import import_braintaichi, raise_braintaichi_not_found from brainpy._src.math.ndarray import Array +bti = import_braintaichi(error_if_not_found=False) __all__ = [ 'csrmv', @@ -60,6 +61,8 @@ def csrmv( The array of shape ``(shape[1] if transpose else shape[0],)`` representing the matrix vector product. """ + if bti is None: + raise_braintaichi_not_found() - return bt_csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose) + return bti.csrmv(data, indices, indptr, vector, shape=shape, transpose=transpose)