diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b6bd1d009..5f829051b 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -50,6 +50,8 @@ CSECachingMapperMixin, ) import immutables +from pymbolic.mapper.equality import ( + EqualityMapper as EqualityMapperBase) from pymbolic.mapper.evaluator import \ CachedEvaluationMapper as EvaluationMapperBase from pymbolic.mapper.substitutor import \ @@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments): return self.rec(expr) + +class EqualityMapper(EqualityMapperBase): + def map_loopy_function_identifier(self, expr, other) -> bool: + return True + + def map_reduction(self, expr, other) -> bool: + return ( + expr.operation == other.operation + and expr.allow_simultaneous == other.allow_simultaneous + and self.rec(expr.expr, other.expr) + and all(iname == other_iname + for iname, other_iname in zip(expr.inames, other.inames))) + + def map_group_hw_index(self, expr, other) -> bool: + return expr.axis == other.axis + + map_local_hw_index = map_group_hw_index + + def map_linear_subscript(self, expr, other) -> bool: + return ( + self.rec(expr.index, other.index) + and self.rec(expr.aggregate, other.aggregate)) + + def map_rule_argument(self, expr, other) -> bool: + return expr.index == other.index + + def map_resolved_function(self, expr, other) -> bool: + return self.rec(expr.function, other.function) + + def map_sub_array_ref(self, expr, other) -> bool: + return ( + len(expr.swept_inames) == len(other.swept_inames) + and self.rec(expr.subscript, other.subscript) + and all(self.rec(iname, other_iname) + for iname, other_iname in zip( + expr.swept_inames, + other.swept_inames)) + ) + + def map_tagged_variable(self, expr, other) -> bool: + return ( + expr.name == other.name + and all(tag == other_tag + for tag, other_tag in zip(expr.tags, other.tags)) + ) + + def map_type_cast(self, expr, other) -> bool: + return ( + expr.type == other.type + and self.rec(expr.child, other.child)) + + def map_fortran_division(self, expr, other) -> bool: + return self.map_quotient(expr, other) + # }}} @@ -515,6 +571,9 @@ def stringifier(self): def make_stringifier(self, originating_stringifier=None): return StringifyMapper() + def make_equality_mapper(self): + return EqualityMapper() + class Literal(LoopyExpressionBase): """A literal to be used during code generation. @@ -522,8 +581,8 @@ class Literal(LoopyExpressionBase): .. note:: Only used in the output of - :mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ def __init__(self, s): @@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase): .. note:: Only used in the output of - :mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ def __init__(self, children): @@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex): .. note:: Only used in the output of - :mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ mapper_method = "map_group_hw_index" @@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex): .. note:: Only used in the output of - :mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and + similar mappers). Not for use in :mod:`loopy` source representation. """ mapper_method = "map_local_hw_index" @@ -792,12 +851,6 @@ def __getinitargs__(self): def get_hash(self): return hash((self.__class__, self.operation, self.inames, self.expr)) - def is_equal(self, other): - return (other.__class__ == self.__class__ - and other.operation == self.operation - and other.inames == self.inames - and other.expr == self.expr) - @property def is_tuple_typed(self): return self.operation.arg_count > 1 @@ -994,14 +1047,6 @@ def __getinitargs__(self): def get_hash(self): return hash((self.__class__, self.swept_inames, self.subscript)) - def is_equal(self, other): - """ - Returns *True* iff the sub-array refs have identical expressions. - """ - return (other.__class__ == self.__class__ - and other.subscript == self.subscript - and other.swept_inames == self.swept_inames) - def make_stringifier(self, originating_stringifier=None): return StringifyMapper() diff --git a/requirements.txt b/requirements.txt index c44f010c3..82d76f5db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/cgen.git#egg=cgen git+https://github.com/inducer/pyopencl.git#egg=pyopencl -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/codepy.git#egg=codepy