Skip to content

Commit

Permalink
Updates (#315)
Browse files Browse the repository at this point in the history
Updates (#315)
  • Loading branch information
ztqakita authored Dec 29, 2022
2 parents 3ebeb2a + d535167 commit 6b86740
Show file tree
Hide file tree
Showing 158 changed files with 5,468 additions and 3,834 deletions.
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ publishment.md
.vscode
io_test_tmp*

brainpy/base/tests/io_test_tmp*
brainpy/math/brainpy_object/tests/io_test_tmp*

development

Expand Down Expand Up @@ -217,3 +217,7 @@ cython_debug/
/docs/apis/simulation/generated/
!/brainpy/dyn/tests/data/
/examples/dynamics_simulation/data/
/examples/training_snn_models/logs/T100_b64_lr0.001/
/examples/training_snn_models/logs/
/examples/training_snn_models/data/
/docs/tutorial_advanced/data/
32 changes: 14 additions & 18 deletions brainpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,8 @@
# fundamental modules
from . import errors, check, tools

# "base" module
from . import base
from .base import (
# base class
Base,
BrainPyObject,

# collector
Collector,
ArrayCollector,
TensorCollector,
)

# math foundation
from . import math
from . import modes

# toolboxes
from . import (
Expand Down Expand Up @@ -69,7 +55,7 @@
synouts, # synaptic output
synplast, # synaptic plasticity

# base classes
# brainpy_object classes
DynamicalSystem,
Container,
Sequential,
Expand Down Expand Up @@ -113,9 +99,7 @@

# running
from . import running
from .running import (
Runner
)
from .running import (Runner)

# "visualization" module, will be removed soon
from .visualization import visualize
Expand All @@ -124,3 +108,15 @@
conn = connect
init = initialize
optim = optimizers

from . import experimental


# deprecated
from . import base
# use ``brainpy.math.*`` instead
from brainpy.math.object_transform.base_object import (Base, BrainPyObject,)
# use ``brainpy.math.*`` instead
from brainpy.math.object_transform.collector import (Collector, ArrayCollector, TensorCollector,)
# use ``brainpy.math.*`` instead
from . import modes
5 changes: 2 additions & 3 deletions brainpy/algorithms/offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from jax.lax import while_loop

import brainpy.math as bm
from brainpy.base import BrainPyObject
from brainpy.types import ArrayType
from .utils import (Sigmoid,
Regularization, L1Regularization, L1L2Regularization, L2Regularization,
polynomial_features, normalize)

__all__ = [
# base class for offline training algorithm
# brainpy_object class for offline training algorithm
'OfflineAlgorithm',

# training methods
Expand All @@ -33,7 +32,7 @@
name2func = dict()


class OfflineAlgorithm(BrainPyObject):
class OfflineAlgorithm(bm.BrainPyObject):
"""Base class for offline training algorithm."""

def __init__(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions brainpy/algorithms/online.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# -*- coding: utf-8 -*-

import brainpy.math as bm
from brainpy.base import BrainPyObject
from jax import vmap
import jax.numpy as jnp

__all__ = [
# base class
# brainpy_object class
'OnlineAlgorithm',

# online learning algorithms
Expand All @@ -21,7 +20,7 @@
name2func = dict()


class OnlineAlgorithm(BrainPyObject):
class OnlineAlgorithm(bm.BrainPyObject):
"""Base class for online training algorithm."""

def __init__(self, name=None):
Expand Down
5 changes: 2 additions & 3 deletions brainpy/analysis/highdim/slow_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import brainpy.math as bm
from brainpy import optimizers as optim, losses
from brainpy.analysis import utils, base, constants
from brainpy.base import ArrayCollector
from brainpy.dyn.base import DynamicalSystem
from brainpy.dyn.runners import check_and_format_inputs, _f_ops
from brainpy.errors import AnalyzerError, UnsupportedError
Expand Down Expand Up @@ -133,11 +132,11 @@ def __init__(

# update function
if target_vars is None:
self.target_vars = ArrayCollector()
self.target_vars = bm.ArrayCollector()
else:
if not isinstance(target_vars, dict):
raise TypeError(f'"target_vars" must be a dict but we got {type(target_vars)}')
self.target_vars = ArrayCollector(target_vars)
self.target_vars = bm.ArrayCollector(target_vars)
excluded_vars = () if excluded_vars is None else excluded_vars
if isinstance(excluded_vars, dict):
excluded_vars = tuple(excluded_vars.values())
Expand Down
9 changes: 4 additions & 5 deletions brainpy/analysis/lowdim/lowdim_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from brainpy import errors, tools
from brainpy.analysis import constants as C, utils
from brainpy.analysis.base import DSAnalyzer
from brainpy.base.collector import Collector

pyplot = None

Expand Down Expand Up @@ -91,7 +90,7 @@ def __init__(
if not isinstance(target_vars, dict):
raise errors.AnalyzerError('"target_vars" must be a dict, with the format of '
'{"var1": (var1_min, var1_max)}.')
self.target_vars = Collector(target_vars)
self.target_vars = bm.Collector(target_vars)
self.target_var_names = list(self.target_vars.keys()) # list of target vars
for key in self.target_vars.keys():
if key not in self.model.variables:
Expand All @@ -110,7 +109,7 @@ def __init__(
for key in fixed_vars.keys():
if key not in self.model.variables:
raise ValueError(f'{key} is not a dynamical variable in {self.model}.')
self.fixed_vars = Collector(fixed_vars)
self.fixed_vars = bm.Collector(fixed_vars)

# check duplicate
for key in self.fixed_vars.keys():
Expand All @@ -125,7 +124,7 @@ def __init__(
if not isinstance(pars_update, dict):
raise errors.AnalyzerError('"pars_update" must be a dict with the format '
'of {"par1": val1, "par2": val2}.')
pars_update = Collector(pars_update)
pars_update = bm.Collector(pars_update)
for key in pars_update.keys():
if key not in self.model.parameters:
raise errors.AnalyzerError(f'"{key}" is not a valid parameter in "{self.model}" model.')
Expand All @@ -144,7 +143,7 @@ def __init__(
raise errors.AnalyzerError(
f'The range of parameter {key} is reversed, which means {value[0]} should be smaller than {value[1]}.')

self.target_pars = Collector(target_pars)
self.target_pars = bm.Collector(target_pars)
self.target_par_names = list(self.target_pars.keys()) # list of target_pars

# check duplicate
Expand Down
30 changes: 8 additions & 22 deletions brainpy/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,14 @@
# -*- coding: utf-8 -*-

"""
The ``base`` module for whole BrainPy ecosystem.
- This module provides the most fundamental class ``BrainPyObject``,
and its associated helper class ``Collector`` and ``ArrayCollector``.
- For each instance of "BrainPyObject" class, users can retrieve all
the variables (or trainable variables), integrators, and nodes.
- This module also provides a ``FunAsObject`` class to wrap user-defined
functions. In each function, maybe several nodes are used, and
users can initialize a ``FunAsObject`` by providing the nodes used
in the function. Unfortunately, ``FunAsObject`` class does not have
the ability to gather nodes automatically.
- This module provides ``io`` helper functions to help users save/load
model states, or share user's customized model with others.
- This module provides ``naming`` tools to guarantee the unique nameing
for each BrainPyObject object.
Details please see the following.
This module is deprecated since version 2.3.1.
Please use ``brainpy.math.*`` instead.
"""

from brainpy.base.base import *
from brainpy.base.collector import *
from brainpy.base.function import *
from brainpy.base.io import *
from brainpy.base.naming import *

from .base import *
from .collector import *
from .function import *
from .io import *
from .naming import *

Loading

0 comments on commit 6b86740

Please sign in to comment.