diff --git a/brian2/codegen/optimisation.py b/brian2/codegen/optimisation.py index a73e0eeef..73fa90966 100644 --- a/brian2/codegen/optimisation.py +++ b/brian2/codegen/optimisation.py @@ -17,7 +17,7 @@ brian_dtype_from_dtype, dtype_hierarchy, ) -from brian2.parsing.rendering import NodeRenderer, get_node_value +from brian2.parsing.rendering import NodeRenderer from brian2.utils.stringtools import get_identifiers, word_substitute from .statements import Statement @@ -271,7 +271,7 @@ def render_BinOp(self, node): if op.__class__.__name__ == "Mult": for operand, other in [(left, right), (right, left)]: if operand.__class__.__name__ in ["Num", "Constant"]: - op_value = get_node_value(operand) + op_value = operand.value if op_value == 0: # Do not remove stateful functions if node.stateless: @@ -286,23 +286,20 @@ def render_BinOp(self, node): # Handle division by 1, or 0/x elif op.__class__.__name__ == "Div": if ( - left.__class__.__name__ in ["Num", "Constant"] - and get_node_value(left) == 0 + left.__class__.__name__ in ["Num", "Constant"] and left.value == 0 ): # 0/x if node.stateless: # Do not remove stateful functions return _replace_with_zero(left, node) if ( - right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 1 + right.__class__.__name__ in ["Num", "Constant"] and right.value == 1 ): # x/1 # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]: return left elif op.__class__.__name__ == "FloorDiv": if ( - left.__class__.__name__ in ["Num", "Constant"] - and get_node_value(left) == 0 + left.__class__.__name__ in ["Num", "Constant"] and left.value == 0 ): # 0//x if node.stateless: # Do not remove stateful functions @@ -313,7 +310,7 @@ def render_BinOp(self, node): if ( left.dtype == right.dtype == "integer" and right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 1 + and right.value == 1 ): # x//1 return left # Handle addition of 0 @@ -321,17 +318,14 @@ def render_BinOp(self, node): for operand, other in [(left, right), (right, left)]: if ( operand.__class__.__name__ in ["Num", "Constant"] - and get_node_value(operand) == 0 + and operand.value == 0 ): # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]: return other # Handle subtraction of 0 elif op.__class__.__name__ == "Sub": - if ( - right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 0 - ): + if right.__class__.__name__ in ["Num", "Constant"] and right.value == 0: # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]: return left @@ -346,12 +340,10 @@ def render_BinOp(self, node): ]: for subnode in [node.left, node.right]: if subnode.__class__.__name__ in ["Num", "Constant"] and not ( - get_node_value(subnode) is True or get_node_value(subnode) is False + subnode.value is True or subnode.value is False ): subnode.dtype = "float" - subnode.value = prefs.core.default_float_dtype( - get_node_value(subnode) - ) + subnode.value = prefs.core.default_float_dtype(subnode.value) return node diff --git a/brian2/parsing/bast.py b/brian2/parsing/bast.py index 6e7b6e7e5..682590c23 100644 --- a/brian2/parsing/bast.py +++ b/brian2/parsing/bast.py @@ -8,7 +8,6 @@ import numpy -from brian2.parsing.rendering import get_node_value from brian2.utils.logger import get_logger __all__ = ["brian_ast", "BrianASTRenderer", "dtype_hierarchy"] @@ -168,7 +167,7 @@ def render_Name(self, node): def render_Num(self, node): node.complexity = 0 - node.dtype = brian_dtype_from_value(get_node_value(node)) + node.dtype = brian_dtype_from_value(node.value) node.scalar = True node.stateless = True return node diff --git a/brian2/parsing/expressions.py b/brian2/parsing/expressions.py index 18aa8fb35..559532422 100644 --- a/brian2/parsing/expressions.py +++ b/brian2/parsing/expressions.py @@ -4,7 +4,7 @@ import ast from brian2.core.functions import Function -from brian2.parsing.rendering import NodeRenderer, get_node_value +from brian2.parsing.rendering import NodeRenderer from brian2.units.fundamentalunits import ( DIMENSIONLESS, DimensionMismatchError, @@ -138,7 +138,7 @@ def _get_value_from_expression(expr, variables): else: raise ValueError(f"Unknown identifier {name}") elif expr.__class__ is ast.Constant: - return get_node_value(expr) + return expr.value elif expr.__class__ is ast.BoolOp: raise SyntaxError( "Cannot determine the numerical value for a boolean operation." diff --git a/brian2/parsing/rendering.py b/brian2/parsing/rendering.py index 8c0193299..8044ac92a 100644 --- a/brian2/parsing/rendering.py +++ b/brian2/parsing/rendering.py @@ -10,25 +10,9 @@ "NumpyNodeRenderer", "CPPNodeRenderer", "SympyNodeRenderer", - "get_node_value", ] -def get_node_value(node): - """Helper function to mask differences between Python versions""" - try: - value = node.value - except AttributeError: - try: - value = node.n - except AttributeError: - value = None - - if value is None: - raise AttributeError(f'Node {node} has neither "n" nor "value" attribute') - return value - - class NodeRenderer: expression_ops = { # BinOp @@ -97,9 +81,9 @@ def render_Name(self, node): return node.id def render_Num(self, node): - return repr(get_node_value(node)) + return repr(node.value) - def render_Constant(self, node): # For literals in Python 3.8 + def render_Constant(self, node): if node.value is True or node.value is False or node.value is None: return self.render_NameConstant(node) else: @@ -130,9 +114,7 @@ def render_element_parentheses(self, node): """ if node.__class__.__name__ == "Name": return self.render_node(node) - elif ( - node.__class__.__name__ in ["Num", "Constant"] and get_node_value(node) >= 0 - ): + elif node.__class__.__name__ in ["Num", "Constant"] and node.value >= 0: return self.render_node(node) elif node.__class__.__name__ == "Call": return self.render_node(node) @@ -285,10 +267,10 @@ def render_NameConstant(self, node): return str(node.value) def render_Num(self, node): - if isinstance(get_node_value(node), numbers.Integral): - return sympy.Integer(get_node_value(node)) + if isinstance(node.value, numbers.Integral): + return sympy.Integer(node.value) else: - return sympy.Float(get_node_value(node)) + return sympy.Float(node.value) def render_BinOp(self, node): op_name = node.op.__class__.__name__