Skip to content

Commit

Permalink
Merge pull request #468 from chaoming0625/master
Browse files Browse the repository at this point in the history
Update operators for compatible with ``brainpylib>=0.1.10``
  • Loading branch information
chaoming0625 authored Sep 6, 2023
2 parents aead878 + e67f142 commit 21848a1
Show file tree
Hide file tree
Showing 20 changed files with 589 additions and 448 deletions.
2 changes: 1 addition & 1 deletion brainpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

__version__ = "2.4.4.post2"
__version__ = "2.4.4.post3"

# fundamental supporting modules
from brainpy import errors, check, tools
Expand Down
104 changes: 54 additions & 50 deletions brainpy/_src/connect/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
'SUPPORTED_SYN_STRUCTURE',

# the connection dtypes
'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE',
'set_default_dtype', 'MAT_DTYPE', 'IDX_DTYPE', 'get_idx_type',

# brainpy_object class
'Connector', 'TwoEndConnector', 'OneEndConnector',
Expand Down Expand Up @@ -59,6 +59,10 @@
IDX_DTYPE = jnp.int32


def get_idx_type():
return IDX_DTYPE


def set_default_dtype(mat_dtype=None, idx_dtype=None):
"""Set the default dtype.
Expand Down Expand Up @@ -247,44 +251,44 @@ def _return_by_csr(self, structures, csr: tuple, all_data: dict):

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type())

if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.as_jax(indices, dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(indices, dtype=get_idx_type())

if (COO in structures) and (COO not in all_data):
pre_ids = np.repeat(np.arange(self.pre_num), np.diff(indptr))
all_data[COO] = (bm.as_jax(pre_ids, dtype=IDX_DTYPE),
bm.as_jax(indices, dtype=IDX_DTYPE))
all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()),
bm.as_jax(indices, dtype=get_idx_type()))

if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.as_jax(indices, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
all_data[PRE2POST] = (bm.as_jax(indices, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (CSR in structures) and (CSR not in all_data):
all_data[CSR] = (bm.as_jax(indices, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
all_data[CSR] = (bm.as_jax(indices, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (POST2PRE in structures) and (POST2PRE not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[POST2PRE] = (bm.as_jax(indc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[POST2PRE] = (bm.as_jax(indc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

if (CSC in structures) and (CSC not in all_data):
indc, indptrc = csr2csc((indices, indptr), self.post_num)
all_data[CSC] = (bm.as_jax(indc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[CSC] = (bm.as_jax(indc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

if (PRE2SYN in structures) and (PRE2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=IDX_DTYPE),
bm.as_jax(indptr, dtype=IDX_DTYPE))
syn_seq = np.arange(indices.size, dtype=get_idx_type())
all_data[PRE2SYN] = (bm.as_jax(syn_seq, dtype=get_idx_type()),
bm.as_jax(indptr, dtype=get_idx_type()))

if (POST2SYN in structures) and (POST2SYN not in all_data):
syn_seq = np.arange(indices.size, dtype=IDX_DTYPE)
syn_seq = np.arange(indices.size, dtype=get_idx_type())
_, indptrc, syn_seqc = csr2csc((indices, indptr), self.post_num, syn_seq)
all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=IDX_DTYPE),
bm.as_jax(indptrc, dtype=IDX_DTYPE))
all_data[POST2SYN] = (bm.as_jax(syn_seqc, dtype=get_idx_type()),
bm.as_jax(indptrc, dtype=get_idx_type()))

def _return_by_coo(self, structures, coo: tuple, all_data: dict):
pre_ids, post_ids = coo
Expand All @@ -293,24 +297,24 @@ def _return_by_coo(self, structures, coo: tuple, all_data: dict):
all_data[CONN_MAT] = bm.as_jax(coo2mat(coo, self.pre_num, self.post_num), dtype=MAT_DTYPE)

if (PRE_IDS in structures) and (PRE_IDS not in all_data):
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(pre_ids, dtype=get_idx_type())

if (POST_IDS in structures) and (POST_IDS not in all_data):
all_data[POST_IDS] = bm.as_jax(post_ids, dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(post_ids, dtype=get_idx_type())

if (COO in structures) and (COO not in all_data):
all_data[COO] = (bm.as_jax(pre_ids, dtype=IDX_DTYPE),
bm.as_jax(post_ids, dtype=IDX_DTYPE))
all_data[COO] = (bm.as_jax(pre_ids, dtype=get_idx_type()),
bm.as_jax(post_ids, dtype=get_idx_type()))

if CSC in structures and CSC not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[CSC] = (bm.as_jax(csc[0], dtype=IDX_DTYPE),
bm.as_jax(csc[1], dtype=IDX_DTYPE))
all_data[CSC] = (bm.as_jax(csc[0], dtype=get_idx_type()),
bm.as_jax(csc[1], dtype=get_idx_type()))

if POST2PRE in structures and POST2PRE not in all_data:
csc = coo2csc(coo, self.post_num)
all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=IDX_DTYPE),
bm.as_jax(csc[1], dtype=IDX_DTYPE))
all_data[POST2PRE] = (bm.as_jax(csc[0], dtype=get_idx_type()),
bm.as_jax(csc[1], dtype=get_idx_type()))

if (len([s for s in structures
if s not in [CONN_MAT, PRE_IDS, POST_IDS,
Expand Down Expand Up @@ -350,8 +354,8 @@ def _make_returns(self, structures, conn_data):
# "csr" structure
if csr is not None:
if (PRE2POST in structures) and (PRE2POST not in all_data):
all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=IDX_DTYPE),
bm.as_jax(csr[1], dtype=IDX_DTYPE))
all_data[PRE2POST] = (bm.as_jax(csr[0], dtype=get_idx_type()),
bm.as_jax(csr[1], dtype=get_idx_type()))
self._return_by_csr(structures, csr=csr, all_data=all_data)

# "mat" structure
Expand All @@ -364,9 +368,9 @@ def _make_returns(self, structures, conn_data):
# "coo" structure
if coo is not None:
if (PRE_IDS in structures) and (PRE_IDS not in structures):
all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=IDX_DTYPE)
all_data[PRE_IDS] = bm.as_jax(coo[0], dtype=get_idx_type())
if (POST_IDS in structures) and (POST_IDS not in structures):
all_data[POST_IDS] = bm.as_jax(coo[1], dtype=IDX_DTYPE)
all_data[POST_IDS] = bm.as_jax(coo[1], dtype=get_idx_type())
self._return_by_coo(structures, coo=coo, all_data=all_data)

# return
Expand Down Expand Up @@ -416,34 +420,34 @@ def require(self, *structures):
if len(structures) == 1:
if PRE2POST in structures and _has_csr_imp:
r = self.build_csr()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
elif CSR in structures and _has_csr_imp:
r = self.build_csr()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
elif CONN_MAT in structures and _has_mat_imp:
return bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
elif PRE_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[0], dtype=IDX_DTYPE)
return bm.as_jax(self.build_coo()[0], dtype=get_idx_type())
elif POST_IDS in structures and _has_coo_imp:
return bm.as_jax(self.build_coo()[1], dtype=IDX_DTYPE)
return bm.as_jax(self.build_coo()[1], dtype=get_idx_type())
elif COO in structures and _has_coo_imp:
r = self.build_coo()
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())

elif len(structures) == 2:
if (PRE_IDS in structures and POST_IDS in structures and _has_coo_imp):
r = self.build_coo()
if structures[0] == PRE_IDS:
return bm.as_jax(r[0], dtype=IDX_DTYPE), bm.as_jax(r[1], dtype=IDX_DTYPE)
return bm.as_jax(r[0], dtype=get_idx_type()), bm.as_jax(r[1], dtype=get_idx_type())
else:
return bm.as_jax(r[1], dtype=IDX_DTYPE), bm.as_jax(r[0], dtype=IDX_DTYPE)
return bm.as_jax(r[1], dtype=get_idx_type()), bm.as_jax(r[0], dtype=get_idx_type())

if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and COO in structures and _has_coo_imp):
csr = self.build_csr()
csr = (bm.as_jax(csr[0], dtype=IDX_DTYPE), bm.as_jax(csr[1], dtype=IDX_DTYPE))
csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type()))
coo = self.build_coo()
coo = (bm.as_jax(coo[0], dtype=IDX_DTYPE), bm.as_jax(coo[1], dtype=IDX_DTYPE))
coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type()))
if structures[0] == COO:
return coo, csr
else:
Expand All @@ -452,7 +456,7 @@ def require(self, *structures):
if ((CSR in structures or PRE2POST in structures)
and _has_csr_imp and CONN_MAT in structures and _has_mat_imp):
csr = self.build_csr()
csr = (bm.as_jax(csr[0], dtype=IDX_DTYPE), bm.as_jax(csr[1], dtype=IDX_DTYPE))
csr = (bm.as_jax(csr[0], dtype=get_idx_type()), bm.as_jax(csr[1], dtype=get_idx_type()))
mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == CONN_MAT:
return mat, csr
Expand All @@ -461,7 +465,7 @@ def require(self, *structures):

if (COO in structures and _has_coo_imp and CONN_MAT in structures and _has_mat_imp):
coo = self.build_coo()
coo = (bm.as_jax(coo[0], dtype=IDX_DTYPE), bm.as_jax(coo[1], dtype=IDX_DTYPE))
coo = (bm.as_jax(coo[0], dtype=get_idx_type()), bm.as_jax(coo[1], dtype=get_idx_type()))
mat = bm.as_jax(self.build_mat(), dtype=MAT_DTYPE)
if structures[0] == COO:
return coo, mat
Expand Down Expand Up @@ -612,7 +616,7 @@ def mat2coo(dense):
pre_ids, post_ids = onp.where(dense > 0)
else:
pre_ids, post_ids = jnp.where(bm.as_jax(dense) > 0)
return pre_ids.astype(dtype=IDX_DTYPE), post_ids.astype(dtype=IDX_DTYPE)
return pre_ids.astype(dtype=get_idx_type()), post_ids.astype(dtype=get_idx_type())


def mat2csc(dense):
Expand Down Expand Up @@ -686,7 +690,7 @@ def coo2csr(coo, num_pre):
final_pre_count = bm.as_jax(final_pre_count)
indptr = final_pre_count.cumsum()
indptr = onp.insert(indptr, 0, 0)
return indices.astype(IDX_DTYPE), indptr.astype(IDX_DTYPE)
return indices.astype(get_idx_type()), indptr.astype(get_idx_type())


def coo2csc(coo, post_num, data=None):
Expand All @@ -695,31 +699,31 @@ def coo2csc(coo, post_num, data=None):
if isinstance(indices, onp.ndarray):
# to maintain the original order of the elements with the same value
sort_ids = onp.argsort(indices)
pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
pre_ids_new = onp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

unique_post_ids, count = onp.unique(indices, return_counts=True)
post_count = onp.zeros(post_num, dtype=IDX_DTYPE)
post_count = onp.zeros(post_num, dtype=get_idx_type())
post_count[unique_post_ids] = count

indptr_new = post_count.cumsum()
indptr_new = onp.insert(indptr_new, 0, 0)
indptr_new = onp.asarray(indptr_new, dtype=IDX_DTYPE)
indptr_new = onp.asarray(indptr_new, dtype=get_idx_type())

else:
pre_ids = bm.as_jax(pre_ids)
indices = bm.as_jax(indices)

# to maintain the original order of the elements with the same value
sort_ids = jnp.argsort(indices)
pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=IDX_DTYPE)
pre_ids_new = jnp.asarray(pre_ids[sort_ids], dtype=get_idx_type())

unique_post_ids, count = jnp.unique(indices, return_counts=True)
post_count = bm.zeros(post_num, dtype=IDX_DTYPE)
post_count = bm.zeros(post_num, dtype=get_idx_type())
post_count[unique_post_ids] = count

indptr_new = post_count.value.cumsum()
indptr_new = jnp.insert(indptr_new, 0, 0)
indptr_new = jnp.asarray(indptr_new, dtype=IDX_DTYPE)
indptr_new = jnp.asarray(indptr_new, dtype=get_idx_type())

if data is None:
return pre_ids_new, indptr_new
Expand Down
Loading

0 comments on commit 21848a1

Please sign in to comment.