diff --git a/brian2/codegen/optimisation.py b/brian2/codegen/optimisation.py index da8d7f906..73fa90966 100644 --- a/brian2/codegen/optimisation.py +++ b/brian2/codegen/optimisation.py @@ -17,7 +17,7 @@ brian_dtype_from_dtype, dtype_hierarchy, ) -from brian2.parsing.rendering import NodeRenderer, get_node_value +from brian2.parsing.rendering import NodeRenderer from brian2.utils.stringtools import get_identifiers, word_substitute from .statements import Statement @@ -167,14 +167,14 @@ def _replace_with_zero(zero_node, node): Parameters ---------- - zero_node : `ast.Num` + zero_node : `ast.Constant` The node to replace node : `ast.Node` The node that determines the type Returns ------- - zero_node : `ast.Num` + zero_node : `ast.Constant` The original ``zero_node`` with its value replaced by 0 or 0.0. """ # must not change the dtype of the output, @@ -244,10 +244,7 @@ def render_node(self, node): else: val = prefs.core.default_float_dtype(val) if node.dtype != "boolean": - if hasattr(ast, "Constant"): - newnode = ast.Constant(val) - else: - newnode = ast.Num(val) + newnode = ast.Constant(val) newnode.dtype = node.dtype newnode.scalar = True newnode.stateless = node.stateless @@ -274,7 +271,7 @@ def render_BinOp(self, node): if op.__class__.__name__ == "Mult": for operand, other in [(left, right), (right, left)]: if operand.__class__.__name__ in ["Num", "Constant"]: - op_value = get_node_value(operand) + op_value = operand.value if op_value == 0: # Do not remove stateful functions if node.stateless: @@ -289,23 +286,20 @@ def render_BinOp(self, node): # Handle division by 1, or 0/x elif op.__class__.__name__ == "Div": if ( - left.__class__.__name__ in ["Num", "Constant"] - and get_node_value(left) == 0 + left.__class__.__name__ in ["Num", "Constant"] and left.value == 0 ): # 0/x if node.stateless: # Do not remove stateful functions return _replace_with_zero(left, node) if ( - right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 1 + right.__class__.__name__ in ["Num", "Constant"] and right.value == 1 ): # x/1 # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]: return left elif op.__class__.__name__ == "FloorDiv": if ( - left.__class__.__name__ in ["Num", "Constant"] - and get_node_value(left) == 0 + left.__class__.__name__ in ["Num", "Constant"] and left.value == 0 ): # 0//x if node.stateless: # Do not remove stateful functions @@ -316,7 +310,7 @@ def render_BinOp(self, node): if ( left.dtype == right.dtype == "integer" and right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 1 + and right.value == 1 ): # x//1 return left # Handle addition of 0 @@ -324,17 +318,14 @@ def render_BinOp(self, node): for operand, other in [(left, right), (right, left)]: if ( operand.__class__.__name__ in ["Num", "Constant"] - and get_node_value(operand) == 0 + and operand.value == 0 ): # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]: return other # Handle subtraction of 0 elif op.__class__.__name__ == "Sub": - if ( - right.__class__.__name__ in ["Num", "Constant"] - and get_node_value(right) == 0 - ): + if right.__class__.__name__ in ["Num", "Constant"] and right.value == 0: # only simplify this if the type wouldn't be cast by the operation if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]: return left @@ -349,12 +340,10 @@ def render_BinOp(self, node): ]: for subnode in [node.left, node.right]: if subnode.__class__.__name__ in ["Num", "Constant"] and not ( - get_node_value(subnode) is True or get_node_value(subnode) is False + subnode.value is True or subnode.value is False ): subnode.dtype = "float" - subnode.value = prefs.core.default_float_dtype( - get_node_value(subnode) - ) + subnode.value = prefs.core.default_float_dtype(subnode.value) return node @@ -588,10 +577,7 @@ def collect(node): # if the fully evaluated node is just the identity/null element then we # don't have to make it into an explicit term if x != op_null: - if hasattr(ast, "Constant"): - num_node = ast.Constant(x) - else: - num_node = ast.Num(x) + num_node = ast.Constant(x) else: num_node = None terms_primary = remaining_terms_primary @@ -612,22 +598,16 @@ def collect(node): node = reduced_node([node, prod_primary], op_primary) if prod_inverted is not None: if node is None: - if hasattr(ast, "Constant"): - node = ast.Constant(op_null_with_dtype) - else: - node = ast.Num(op_null_with_dtype) + node = ast.Constant(op_null_with_dtype) node = ast.BinOp(node, op_inverted(), prod_inverted) if node is None: # everything cancelled - if hasattr(ast, "Constant"): - node = ast.Constant(op_null_with_dtype) - else: - node = ast.Num(op_null_with_dtype) + node = ast.Constant(op_null_with_dtype) if ( hasattr(node, "dtype") and dtype_hierarchy[node.dtype] < dtype_hierarchy[orignode_dtype] ): - node = ast.BinOp(ast.Num(op_null_with_dtype), op_primary(), node) + node = ast.BinOp(ast.Constant(op_null_with_dtype), op_primary(), node) node.collected = True return node diff --git a/brian2/parsing/bast.py b/brian2/parsing/bast.py index 9b61d6865..682590c23 100644 --- a/brian2/parsing/bast.py +++ b/brian2/parsing/bast.py @@ -8,7 +8,6 @@ import numpy -from brian2.parsing.rendering import get_node_value from brian2.utils.logger import get_logger __all__ = ["brian_ast", "BrianASTRenderer", "dtype_hierarchy"] @@ -168,12 +167,12 @@ def render_Name(self, node): def render_Num(self, node): node.complexity = 0 - node.dtype = brian_dtype_from_value(get_node_value(node)) + node.dtype = brian_dtype_from_value(node.value) node.scalar = True node.stateless = True return node - def render_Constant(self, node): # For literals in Python 3.8 + def render_Constant(self, node): # For literals in Python >= 3.8 if node.value is True or node.value is False or node.value is None: return self.render_NameConstant(node) else: diff --git a/brian2/parsing/expressions.py b/brian2/parsing/expressions.py index cd76de036..559532422 100644 --- a/brian2/parsing/expressions.py +++ b/brian2/parsing/expressions.py @@ -70,10 +70,7 @@ def is_boolean_expression(expr, variables): raise SyntaxError( "Expression ought to be boolean but is not (e.g. 'x= 0 - ): + elif node.__class__.__name__ in ["Num", "Constant"] and node.value >= 0: return self.render_node(node) elif node.__class__.__name__ == "Call": return self.render_node(node) @@ -278,10 +267,10 @@ def render_NameConstant(self, node): return str(node.value) def render_Num(self, node): - if isinstance(node.n, numbers.Integral): - return sympy.Integer(node.n) + if isinstance(node.value, numbers.Integral): + return sympy.Integer(node.value) else: - return sympy.Float(node.n) + return sympy.Float(node.value) def render_BinOp(self, node): op_name = node.op.__class__.__name__ @@ -356,7 +345,6 @@ def render_BinOp(self, node): return NodeRenderer.render_BinOp(self, node) def render_NameConstant(self, node): - # In Python 3.4, None, True and False go here return {True: "true", False: "false"}.get(node.value, node.value) def render_Name(self, node):