From 931d32619ecf0535751aaec6cbfa95fe70135e39 Mon Sep 17 00:00:00 2001 From: David Rajaratnam Date: Fri, 1 Apr 2022 11:32:00 +1100 Subject: [PATCH] Comparison operators for Predicate map directly to the Symbol object Based on discussions https://github.com/potassco/clorm/issues/97 the decision is to overload the Predicate instance comparison to simply call the underlying Symbol object. --- clorm/orm/core.py | 66 +++++--------------- clorm/orm/templating.py | 67 ++++++++++++++++++-- tests/test_orm_core.py | 131 +++++++++------------------------------- 3 files changed, 105 insertions(+), 159 deletions(-) diff --git a/clorm/orm/core.py b/clorm/orm/core.py index fa409c8..1934e12 100644 --- a/clorm/orm/core.py +++ b/clorm/orm/core.py @@ -2384,13 +2384,6 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict): tmp.append(f"{f.name}_cltopy(raw_args[{idx}]), ") args_cltopy= "".join(tmp) - if pdefn.is_tuple: - hash_eval_self = "hash(self._field_values)" - hash_eval_instance = "hash(instance._field_values)" - else: - hash_eval_self = f"""hash((self._sign, "{pdefn.name}", self._field_values))""" - hash_eval_instance = f"""hash((instance._sign, "{pdefn.name}", instance._field_values))""" - expansions = {"args_signature": args_signature, "sign_check": sign_check, "args": args, @@ -2399,9 +2392,7 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict): "check_complex": check_complex, "args_raw": args_raw, "sign_check_unify": sign_check_unify, - "args_cltopy": args_cltopy, - "hash_evaluation_self": hash_eval_self, - "hash_evaluation_instance": hash_eval_instance} + "args_cltopy": args_cltopy} bool_status = not pdefn.is_tuple or len(pdefn) > 0 template = PREDICATE_TEMPLATE.format(pdefn=pdefn, bool_status=bool_status) @@ -2911,61 +2902,33 @@ def __neg__(self): #-------------------------------------------------------------------------- def __eq__(self, other): """Overloaded boolean operator.""" - if isinstance(other, self.__class__): - return self._field_values == other._field_values and \ - self._sign == other._sign - if self._meta.is_tuple: - return self._field_values == other - elif isinstance(other, Predicate): - return False + if isinstance(other, Predicate): + return self._raw == other._raw return NotImplemented def __lt__(self, other): """Overloaded boolean operator.""" - - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - - # Negative literals are less than positive literals - if self._sign != other._sign: return self._sign < other._sign - - return self._field_values < other._field_values - - # If different predicates then compare the raw value - elif isinstance(other, Predicate): + if isinstance(other, Predicate): return self._raw < other._raw - - # Else an error return NotImplemented - def __ge__(self, other): + def __le__(self, other): """Overloaded boolean operator.""" - result = self.__lt__(other) - if result is NotImplemented: return NotImplemented - return not result + if isinstance(other, Predicate): + return self._raw <= other._raw + return NotImplemented def __gt__(self, other): """Overloaded boolean operator.""" - - # If it is the same predicate class then compare the sign and fields - if isinstance(other, self.__class__): - # Positive literals are greater than negative literals - if self._sign != other._sign: return self._sign > other._sign - - return self._field_values > other._field_values - - # If different predicates then compare the raw value - if not isinstance(other, Predicate): + if isinstance(other, Predicate): return self._raw > other._raw - - # Else an error return NotImplemented - def __le__(self, other): + def __ge__(self, other): """Overloaded boolean operator.""" - result = self.__gt__(other) - if result is NotImplemented: return NotImplemented - return not result + if isinstance(other, Predicate): + return self._raw >= other._raw + return NotImplemented def __hash__(self): return self._hash @@ -2985,13 +2948,12 @@ def __getstate__(self): def __setstate__(self, newstate): self._field_values = newstate["_field_values"] self._sign = newstate["_sign"] - if self._meta.is_tuple: self._hash = hash(self._field_values) - else: self._hash = hash((self._sign, self._meta.name, self._field_values)) clingoargs=[] for f,v in zip(self._meta, self._field_values): clingoargs.append(v.symbol if f.defn.complex else f.defn.pytocl(v)) self._raw = Function(self._meta.name, clingoargs, self._sign) + self._hash = hash(self._raw) #------------------------------------------------------------------------------ diff --git a/clorm/orm/templating.py b/clorm/orm/templating.py index 1c04c1b..76d932b 100644 --- a/clorm/orm/templating.py +++ b/clorm/orm/templating.py @@ -65,12 +65,11 @@ def __init__(self, self._field_values = ({{%args%}}) - self._hash = {{%hash_evaluation_self%}} - - # Create the raw symbol + # Create the raw symbol and cache the hash self._raw = Function("{pdefn.name}", ({{%args_raw%}}), self._sign) + self._hash = hash(self._raw) @classmethod def _unify(cls: Type[_P], raw: AnySymbol) -> Optional[_P]: @@ -86,10 +85,9 @@ def _unify(cls: Type[_P], raw: AnySymbol) -> Optional[_P]: instance = cls.__new__(cls) instance._raw = raw - instance._hash = None + instance._hash = hash(raw) instance._sign = raw.positive instance._field_values = ({{%args_cltopy%}}) - instance._hash = {{%hash_evaluation_instance%}} return instance except (TypeError, ValueError): return None @@ -104,8 +102,67 @@ def __bool__(self): def __len__(self): return {pdefn.arity} + +def __eq__(self, other): + if self.__class__ == other.__class__: + return (self._sign == other._sign) and (self._field_values == other._field_values) + + return NotImplemented + +def __lt__(self, other): + if self.__class__ == other.__class__: + return (self._sign < other._sign) or (self._field_values < other._field_values) + + return NoImplemented """ +PREDICATE_EQ_NON_TUPLE =""" +def __eq__(self, other): + if self.__class__ == other.__class__: + return (self._sign == other._sign) and (self._field_values == other._field_values) + return False if ininstance(other, Predicate) else NotImplemented +""" + +PREDICATE_EQ_TUPLE =""" +def __eq__(self, other): + if self.__class__ == other.__class__: + return self._field_values == other._field_values + if isinstance(other, Predicate): + return self._field_values == other._field_values if other._meta.is_tuple else False + return other == self._field_values +""" + +PREDICATE_CMP_NON_TUPLE=r""" +def __{name}__(self, other): + if self.__class__ == other.__class__: + # Note: rely on False < True (True > False) and evaluation order + return (self._sign {op} other._sign) or (self._field_values {op} other._field_values) + + if isinstance(other, Predicate): + return self._raw {op} other._raw + + return NotImplemented +""" + +PREDICATE_CMP_TUPLE=r""" +def __{name}__(self, other): + if self.__class__ == other.__class__: + self._field_values {op} other._field_values + + if isinstance(other, Predicate): + return self._raw {op} other._raw + + return other {op} self._field_values +""" + +def make_predicate_cmp(name, op, is_tuple=False): + if is_tuple: + return PREDICATE_CMP_NON_TUPLE.format(name=name, op=op) + return PREDICATE_CMP_TUPLE.format(name=name, op=op) + + + + CHECK_SIGN_TEMPLATE = r""" # Check if the sign is allowed if self._sign != {sign}: diff --git a/tests/test_orm_core.py b/tests/test_orm_core.py index e0c200f..55ffa75 100644 --- a/tests/test_orm_core.py +++ b/tests/test_orm_core.py @@ -13,6 +13,7 @@ import unittest import datetime import operator +import itertools import enum import collections.abc as cabc @@ -242,13 +243,13 @@ def test_api_field_function(self): self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, IntegerField) - self.assertEqual(t.default, ("3",4)) + self.assertEqual(tuple([*t.default]), ("3",4)) with self.subTest("with custom field"): INLField = define_flat_list_field(IntegerField,name="INLField") t = field(INLField,default=[3,4,5]) self.assertTrue(isinstance(t, INLField)) - self.assertEqual(t.default, [3,4,5]) + self.assertEqual([*t.default], [3,4,5]) with self.subTest("with default factory"): t = field(IntegerField, default_factory=lambda: 42) @@ -267,10 +268,11 @@ def factory(): self.assertEqual(t, (StringField,(StringField,IntegerField))) t = field((StringField,(StringField,IntegerField)),default=("3",("1",4))) + tval = t.cltopy(t.pytocl(("3",("1",4)))) self.assertIsInstance(t, BaseField) self.assertIsInstance(t.complex[0].meta.field, StringField) self.assertIsInstance(t.complex[1].meta.field, BaseField) - self.assertEqual(t.default, ("3",("1",4))) + self.assertEqual(t.default, tval) def test_api_field_function_illegal_arguments(self): with self.subTest("illegal basefield type"): @@ -1089,17 +1091,16 @@ class T(Predicate): self.assertTrue(p1.tuple_ == p1_alt.tuple_) self.assertTrue(p1.tuple_ == q1.tuple_) self.assertTrue(p1.tuple_ == t1.tuple_) - self.assertTrue(p1.tuple_ == tuple1) + + self.assertFalse(p1.tuple_ == tuple1) self.assertNotEqual(type(p1.tuple_), type(t1.tuple_)) -# self.assertNotEqual(type(p1.tuple_), type(t1)) self.assertTrue(p1.tuple_ != p2.tuple_) self.assertTrue(p1.tuple_ != q2.tuple_) self.assertTrue(p1.tuple_ != r2.tuple_) self.assertTrue(p1.tuple_ != s2.tuple_) self.assertTrue(p1.tuple_ != t2.tuple_) - self.assertTrue(p1.tuple_ != tuple2) #-------------------------------------------------------------------------- # Test predicates with default fields @@ -1918,35 +1919,6 @@ class Meta: is_tuple = True pos_f1=F3(1,sign=True) self.assertEqual(F3._unify(pos_raw), None) - #-------------------------------------------------------------------------- - # Test predicate equality - # -------------------------------------------------------------------------- - def test_predicate_comparison_operator_overload_signed(self): - class P(Predicate): - a = IntegerField - class Q(Predicate): - a = IntegerField - - p1 = P(1) ; neg_p1=P(1,sign=False) ; p2 = P(2) ; neg_p2=P(2,sign=False) - q1 = Q(1) - - self.assertTrue(neg_p1 < neg_p2) - self.assertTrue(neg_p1 < p1) - self.assertTrue(neg_p1 < p2) - self.assertTrue(neg_p2 < p1) - self.assertTrue(neg_p2 < p2) - self.assertTrue(p1 < p2) - - self.assertTrue(p2 > p1) - self.assertTrue(p2 > neg_p2) - self.assertTrue(p2 > neg_p1) - self.assertTrue(p1 > neg_p2) - self.assertTrue(p1 > neg_p1) - self.assertTrue(neg_p2 > neg_p1) - - # Different predicate sub-classes are incomparable -# with self.assertRaises(TypeError) as ctx: -# self.assertTrue(p1 < q1) #-------------------------------------------------------------------------- # Test a simple predicate with a field that has a function default @@ -2214,41 +2186,7 @@ class Fact(Predicate): #-------------------------------------------------------------------------- # Test predicate equality # -------------------------------------------------------------------------- - def test_predicate_comparison_operator_overloads(self): - - f1 = Function("fact", [Number(1)]) - f2 = Function("fact", [Number(2)]) - - class Fact(Predicate): - anum = IntegerField() - - af1 = Fact(anum=1) - af2 = Fact(anum=2) - af1_c = Fact(anum=1) - - self.assertEqual(f1, af1.raw) - self.assertEqual(af1, af1_c) - self.assertNotEqual(af1, af2) - self.assertEqual(str(f1), str(af1)) - - # comparing predicates of different types or to a raw should return - # false even if the underlying raw symbol is identical - class Fact2(Predicate): - anum = IntegerField() - class Meta: name = "fact" - ag1 = Fact2(anum=1) - - self.assertEqual(f1, af1.raw) - self.assertEqual(af1.raw, f1) - self.assertEqual(af1.raw, ag1.raw) - self.assertNotEqual(af1, ag1) - self.assertNotEqual(af1, f1) - self.assertNotEqual(f1, af1) - - self.assertTrue(af1 < af2) - self.assertTrue(af1 <= af2) - self.assertTrue(af2 > af1) - self.assertTrue(af2 >= af1) + def test_predicate_comparison_operator_overloads_with_symbol(self): # clingo.Symbol currently does not implement NotImplemented for # comparison between Symbol and some unknown type so the following @@ -2263,40 +2201,29 @@ class Meta: name = "fact" self.assertTrue(af1 <= f2) self.assertTrue(f2 >= af1) + #-------------------------------------------------------------------------- - # Test predicate equality + # Test predicate equality - comparison between predicate instances is just + # following the comparison of the underlying symbol objects. # -------------------------------------------------------------------------- - def test_comparison_operator_overloads_complex(self): - - class SwapField(IntegerField): - pytocl = lambda x: 100 - x - cltopy = lambda x: 100 - x - - class AComplex(ComplexTerm): - swap=SwapField() - norm=IntegerField() - - f1 = AComplex(swap=99,norm=1) - f2 = AComplex(swap=98,norm=2) - f3 = AComplex(swap=97,norm=3) - f4 = AComplex(swap=97,norm=3) - - rf1 = f1.raw - rf2 = f2.raw - rf3 = f3.raw - for rf in [rf1,rf2,rf3]: - self.assertEqual(rf.arguments[0],rf.arguments[1]) - - # Test the the comparison operator for the complex term is using the - # swapped values so that the comparison is opposite to what the raw - # field says. - self.assertTrue(rf1 < rf2) - self.assertTrue(rf2 < rf3) - self.assertTrue(f1 > f2) - self.assertTrue(f2 > f3) - self.assertTrue(f2 < f1) - self.assertTrue(f3 < f2) - self.assertEqual(f3,f4) + def test_predicate_comparison_operator_overloads(self): + class P(Predicate): + a = IntegerField + class Q(Predicate): + a = (IntegerField, StringField) + + p1 = P(1) ; neg_p1=P(1,sign=False) ; p2 = P(2) ; neg_p2=P(2,sign=False) + q1 = Q((1,"a")) ; neg_q1=q1.clone(sign=False) + q2 = Q((2,"b")) ; neg_q2=q2.clone(sign=False) + + operators = (operator.lt, operator.le, operator.eq, + operator.ne, operator.ge, operator.gt) + facts = (p1, neg_p1, p2, neg_p2, q1, neg_q1, q2, neg_q2) + + for x,y in itertools.product(facts, repeat=2): + for op in operators: + self.assertTrue(op(x,y) == op(x.raw,y.raw)) + #-------------------------------------------------------------------------- # Test unifying a symbol with a predicate # --------------------------------------------------------------------------