Skip to content

Commit

Permalink
Merge pull request #1482 from brian-team/ast_deprecation_fixes
Browse files Browse the repository at this point in the history
AST deprecation fixes
  • Loading branch information
mstimberg authored Sep 7, 2023
2 parents a7c8e35 + bb154dd commit c2789b2
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 88 deletions.
54 changes: 17 additions & 37 deletions brian2/codegen/optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
brian_dtype_from_dtype,
dtype_hierarchy,
)
from brian2.parsing.rendering import NodeRenderer, get_node_value
from brian2.parsing.rendering import NodeRenderer
from brian2.utils.stringtools import get_identifiers, word_substitute

from .statements import Statement
Expand Down Expand Up @@ -167,14 +167,14 @@ def _replace_with_zero(zero_node, node):
Parameters
----------
zero_node : `ast.Num`
zero_node : `ast.Constant`
The node to replace
node : `ast.Node`
The node that determines the type
Returns
-------
zero_node : `ast.Num`
zero_node : `ast.Constant`
The original ``zero_node`` with its value replaced by 0 or 0.0.
"""
# must not change the dtype of the output,
Expand Down Expand Up @@ -244,10 +244,7 @@ def render_node(self, node):
else:
val = prefs.core.default_float_dtype(val)
if node.dtype != "boolean":
if hasattr(ast, "Constant"):
newnode = ast.Constant(val)
else:
newnode = ast.Num(val)
newnode = ast.Constant(val)
newnode.dtype = node.dtype
newnode.scalar = True
newnode.stateless = node.stateless
Expand All @@ -274,7 +271,7 @@ def render_BinOp(self, node):
if op.__class__.__name__ == "Mult":
for operand, other in [(left, right), (right, left)]:
if operand.__class__.__name__ in ["Num", "Constant"]:
op_value = get_node_value(operand)
op_value = operand.value
if op_value == 0:
# Do not remove stateful functions
if node.stateless:
Expand All @@ -289,23 +286,20 @@ def render_BinOp(self, node):
# Handle division by 1, or 0/x
elif op.__class__.__name__ == "Div":
if (
left.__class__.__name__ in ["Num", "Constant"]
and get_node_value(left) == 0
left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
): # 0/x
if node.stateless:
# Do not remove stateful functions
return _replace_with_zero(left, node)
if (
right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 1
right.__class__.__name__ in ["Num", "Constant"] and right.value == 1
): # x/1
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
return left
elif op.__class__.__name__ == "FloorDiv":
if (
left.__class__.__name__ in ["Num", "Constant"]
and get_node_value(left) == 0
left.__class__.__name__ in ["Num", "Constant"] and left.value == 0
): # 0//x
if node.stateless:
# Do not remove stateful functions
Expand All @@ -316,25 +310,22 @@ def render_BinOp(self, node):
if (
left.dtype == right.dtype == "integer"
and right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 1
and right.value == 1
): # x//1
return left
# Handle addition of 0
elif op.__class__.__name__ == "Add":
for operand, other in [(left, right), (right, left)]:
if (
operand.__class__.__name__ in ["Num", "Constant"]
and get_node_value(operand) == 0
and operand.value == 0
):
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[operand.dtype] <= dtype_hierarchy[other.dtype]:
return other
# Handle subtraction of 0
elif op.__class__.__name__ == "Sub":
if (
right.__class__.__name__ in ["Num", "Constant"]
and get_node_value(right) == 0
):
if right.__class__.__name__ in ["Num", "Constant"] and right.value == 0:
# only simplify this if the type wouldn't be cast by the operation
if dtype_hierarchy[right.dtype] <= dtype_hierarchy[left.dtype]:
return left
Expand All @@ -349,12 +340,10 @@ def render_BinOp(self, node):
]:
for subnode in [node.left, node.right]:
if subnode.__class__.__name__ in ["Num", "Constant"] and not (
get_node_value(subnode) is True or get_node_value(subnode) is False
subnode.value is True or subnode.value is False
):
subnode.dtype = "float"
subnode.value = prefs.core.default_float_dtype(
get_node_value(subnode)
)
subnode.value = prefs.core.default_float_dtype(subnode.value)
return node


Expand Down Expand Up @@ -588,10 +577,7 @@ def collect(node):
# if the fully evaluated node is just the identity/null element then we
# don't have to make it into an explicit term
if x != op_null:
if hasattr(ast, "Constant"):
num_node = ast.Constant(x)
else:
num_node = ast.Num(x)
num_node = ast.Constant(x)
else:
num_node = None
terms_primary = remaining_terms_primary
Expand All @@ -612,22 +598,16 @@ def collect(node):
node = reduced_node([node, prod_primary], op_primary)
if prod_inverted is not None:
if node is None:
if hasattr(ast, "Constant"):
node = ast.Constant(op_null_with_dtype)
else:
node = ast.Num(op_null_with_dtype)
node = ast.Constant(op_null_with_dtype)
node = ast.BinOp(node, op_inverted(), prod_inverted)

if node is None: # everything cancelled
if hasattr(ast, "Constant"):
node = ast.Constant(op_null_with_dtype)
else:
node = ast.Num(op_null_with_dtype)
node = ast.Constant(op_null_with_dtype)
if (
hasattr(node, "dtype")
and dtype_hierarchy[node.dtype] < dtype_hierarchy[orignode_dtype]
):
node = ast.BinOp(ast.Num(op_null_with_dtype), op_primary(), node)
node = ast.BinOp(ast.Constant(op_null_with_dtype), op_primary(), node)
node.collected = True
return node

Expand Down
5 changes: 2 additions & 3 deletions brian2/parsing/bast.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

import numpy

from brian2.parsing.rendering import get_node_value
from brian2.utils.logger import get_logger

__all__ = ["brian_ast", "BrianASTRenderer", "dtype_hierarchy"]
Expand Down Expand Up @@ -168,12 +167,12 @@ def render_Name(self, node):

def render_Num(self, node):
node.complexity = 0
node.dtype = brian_dtype_from_value(get_node_value(node))
node.dtype = brian_dtype_from_value(node.value)
node.scalar = True
node.stateless = True
return node

def render_Constant(self, node): # For literals in Python 3.8
def render_Constant(self, node): # For literals in Python >= 3.8
if node.value is True or node.value is False or node.value is None:
return self.render_NameConstant(node)
else:
Expand Down
33 changes: 4 additions & 29 deletions brian2/parsing/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,7 @@ def is_boolean_expression(expr, variables):
raise SyntaxError(
"Expression ought to be boolean but is not (e.g. 'x<y and 3')"
)
elif expr.__class__ in [
getattr(ast, "NameConstant", None),
getattr(ast, "Constant", None),
]:
elif expr.__class__ is ast.Constant:
value = expr.value
if value is True or value is False:
return True
Expand Down Expand Up @@ -140,21 +137,8 @@ def _get_value_from_expression(expr, variables):
return 1.0 if name == "True" else 0.0
else:
raise ValueError(f"Unknown identifier {name}")
elif expr.__class__ is getattr(ast, "NameConstant", None):
value = expr.value
if value is True or value is False:
return 1.0 if value else 0.0
else:
raise ValueError(f"Do not know how to deal with value {value}")
elif expr.__class__ is ast.Num or expr.__class__ is getattr(
ast, "Constant", None
): # Python 3.8
# In Python 3.8, boolean values are represented by Constant, not by
# NameConstant
if expr.n is True or expr.n is False:
return 1.0 if expr.n else 0.0
else:
return expr.n
elif expr.__class__ is ast.Constant:
return expr.value
elif expr.__class__ is ast.BoolOp:
raise SyntaxError(
"Cannot determine the numerical value for a boolean operation."
Expand Down Expand Up @@ -231,13 +215,6 @@ def parse_expression_dimensions(expr, variables, orig_expr=None):
orig_expr = expr
mod = ast.parse(expr, mode="eval")
expr = mod.body
if expr.__class__ is getattr(ast, "NameConstant", None):
# new class for True, False, None in Python 3.4
value = expr.value
if value is True or value is False:
return DIMENSIONLESS
else:
raise ValueError(f"Do not know how to handle value {value}")
if expr.__class__ is ast.Name:
name = expr.id
# Raise an error if a function is called as if it were a variable
Expand All @@ -253,9 +230,7 @@ def parse_expression_dimensions(expr, variables, orig_expr=None):
return DIMENSIONLESS
else:
raise KeyError(f"Unknown identifier {name}")
elif expr.__class__ is ast.Num or expr.__class__ is getattr(
ast, "Constant", None
): # Python 3.8
elif expr.__class__ is ast.Constant:
return DIMENSIONLESS
elif expr.__class__ is ast.BoolOp:
# check that the units are valid in each subexpression
Expand Down
26 changes: 7 additions & 19 deletions brian2/parsing/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,9 @@
"NumpyNodeRenderer",
"CPPNodeRenderer",
"SympyNodeRenderer",
"get_node_value",
]


def get_node_value(node):
"""Helper function to mask differences between Python versions"""
value = getattr(node, "n", getattr(node, "value", None))
if value is None:
raise AttributeError(f'Node {node} has neither "n" nor "value" attribute')
return value


class NodeRenderer:
expression_ops = {
# BinOp
Expand Down Expand Up @@ -90,9 +81,9 @@ def render_Name(self, node):
return node.id

def render_Num(self, node):
return repr(get_node_value(node))
return repr(node.value)

def render_Constant(self, node): # For literals in Python 3.8
def render_Constant(self, node):
if node.value is True or node.value is False or node.value is None:
return self.render_NameConstant(node)
else:
Expand Down Expand Up @@ -121,11 +112,9 @@ def render_element_parentheses(self, node):
Render an element with parentheses around it or leave them away for
numbers, names and function calls.
"""
if node.__class__.__name__ in ["Name", "NameConstant"]:
if node.__class__.__name__ == "Name":
return self.render_node(node)
elif (
node.__class__.__name__ in ["Num", "Constant"] and get_node_value(node) >= 0
):
elif node.__class__.__name__ in ["Num", "Constant"] and node.value >= 0:
return self.render_node(node)
elif node.__class__.__name__ == "Call":
return self.render_node(node)
Expand Down Expand Up @@ -278,10 +267,10 @@ def render_NameConstant(self, node):
return str(node.value)

def render_Num(self, node):
if isinstance(node.n, numbers.Integral):
return sympy.Integer(node.n)
if isinstance(node.value, numbers.Integral):
return sympy.Integer(node.value)
else:
return sympy.Float(node.n)
return sympy.Float(node.value)

def render_BinOp(self, node):
op_name = node.op.__class__.__name__
Expand Down Expand Up @@ -356,7 +345,6 @@ def render_BinOp(self, node):
return NodeRenderer.render_BinOp(self, node)

def render_NameConstant(self, node):
# In Python 3.4, None, True and False go here
return {True: "true", False: "false"}.get(node.value, node.value)

def render_Name(self, node):
Expand Down

0 comments on commit c2789b2

Please sign in to comment.