Skip to content

Commit

Permalink
Comparison operators for Predicate map directly to the Symbol object
Browse files Browse the repository at this point in the history
Based on discussions #97

the decision is to overload the Predicate instance comparison to simply call
the underlying Symbol object.
  • Loading branch information
daveraja committed Apr 1, 2022
1 parent ee82644 commit 931d326
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 159 deletions.
66 changes: 14 additions & 52 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2384,13 +2384,6 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict):
tmp.append(f"{f.name}_cltopy(raw_args[{idx}]), ")
args_cltopy= "".join(tmp)

if pdefn.is_tuple:
hash_eval_self = "hash(self._field_values)"
hash_eval_instance = "hash(instance._field_values)"
else:
hash_eval_self = f"""hash((self._sign, "{pdefn.name}", self._field_values))"""
hash_eval_instance = f"""hash((instance._sign, "{pdefn.name}", instance._field_values))"""

expansions = {"args_signature": args_signature,
"sign_check": sign_check,
"args": args,
Expand All @@ -2399,9 +2392,7 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict):
"check_complex": check_complex,
"args_raw": args_raw,
"sign_check_unify": sign_check_unify,
"args_cltopy": args_cltopy,
"hash_evaluation_self": hash_eval_self,
"hash_evaluation_instance": hash_eval_instance}
"args_cltopy": args_cltopy}

bool_status = not pdefn.is_tuple or len(pdefn) > 0
template = PREDICATE_TEMPLATE.format(pdefn=pdefn, bool_status=bool_status)
Expand Down Expand Up @@ -2911,61 +2902,33 @@ def __neg__(self):
#--------------------------------------------------------------------------
def __eq__(self, other):
"""Overloaded boolean operator."""
if isinstance(other, self.__class__):
return self._field_values == other._field_values and \
self._sign == other._sign
if self._meta.is_tuple:
return self._field_values == other
elif isinstance(other, Predicate):
return False
if isinstance(other, Predicate):
return self._raw == other._raw
return NotImplemented

def __lt__(self, other):
"""Overloaded boolean operator."""

# If it is the same predicate class then compare the sign and fields
if isinstance(other, self.__class__):

# Negative literals are less than positive literals
if self._sign != other._sign: return self._sign < other._sign

return self._field_values < other._field_values

# If different predicates then compare the raw value
elif isinstance(other, Predicate):
if isinstance(other, Predicate):
return self._raw < other._raw

# Else an error
return NotImplemented

def __ge__(self, other):
def __le__(self, other):
"""Overloaded boolean operator."""
result = self.__lt__(other)
if result is NotImplemented: return NotImplemented
return not result
if isinstance(other, Predicate):
return self._raw <= other._raw
return NotImplemented

def __gt__(self, other):
"""Overloaded boolean operator."""

# If it is the same predicate class then compare the sign and fields
if isinstance(other, self.__class__):
# Positive literals are greater than negative literals
if self._sign != other._sign: return self._sign > other._sign

return self._field_values > other._field_values

# If different predicates then compare the raw value
if not isinstance(other, Predicate):
if isinstance(other, Predicate):
return self._raw > other._raw

# Else an error
return NotImplemented

def __le__(self, other):
def __ge__(self, other):
"""Overloaded boolean operator."""
result = self.__gt__(other)
if result is NotImplemented: return NotImplemented
return not result
if isinstance(other, Predicate):
return self._raw >= other._raw
return NotImplemented

def __hash__(self):
return self._hash
Expand All @@ -2985,13 +2948,12 @@ def __getstate__(self):
def __setstate__(self, newstate):
self._field_values = newstate["_field_values"]
self._sign = newstate["_sign"]
if self._meta.is_tuple: self._hash = hash(self._field_values)
else: self._hash = hash((self._sign, self._meta.name, self._field_values))

clingoargs=[]
for f,v in zip(self._meta, self._field_values):
clingoargs.append(v.symbol if f.defn.complex else f.defn.pytocl(v))
self._raw = Function(self._meta.name, clingoargs, self._sign)
self._hash = hash(self._raw)


#------------------------------------------------------------------------------
Expand Down
67 changes: 62 additions & 5 deletions clorm/orm/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,11 @@ def __init__(self,
self._field_values = ({{%args%}})
self._hash = {{%hash_evaluation_self%}}
# Create the raw symbol
# Create the raw symbol and cache the hash
self._raw = Function("{pdefn.name}",
({{%args_raw%}}),
self._sign)
self._hash = hash(self._raw)
@classmethod
def _unify(cls: Type[_P], raw: AnySymbol) -> Optional[_P]:
Expand All @@ -86,10 +85,9 @@ def _unify(cls: Type[_P], raw: AnySymbol) -> Optional[_P]:
instance = cls.__new__(cls)
instance._raw = raw
instance._hash = None
instance._hash = hash(raw)
instance._sign = raw.positive
instance._field_values = ({{%args_cltopy%}})
instance._hash = {{%hash_evaluation_instance%}}
return instance
except (TypeError, ValueError):
return None
Expand All @@ -104,8 +102,67 @@ def __bool__(self):
def __len__(self):
return {pdefn.arity}
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self._sign == other._sign) and (self._field_values == other._field_values)
return NotImplemented
def __lt__(self, other):
if self.__class__ == other.__class__:
return (self._sign < other._sign) or (self._field_values < other._field_values)
return NoImplemented
"""

PREDICATE_EQ_NON_TUPLE ="""
def __eq__(self, other):
if self.__class__ == other.__class__:
return (self._sign == other._sign) and (self._field_values == other._field_values)
return False if ininstance(other, Predicate) else NotImplemented
"""

PREDICATE_EQ_TUPLE ="""
def __eq__(self, other):
if self.__class__ == other.__class__:
return self._field_values == other._field_values
if isinstance(other, Predicate):
return self._field_values == other._field_values if other._meta.is_tuple else False
return other == self._field_values
"""

PREDICATE_CMP_NON_TUPLE=r"""
def __{name}__(self, other):
if self.__class__ == other.__class__:
# Note: rely on False < True (True > False) and evaluation order
return (self._sign {op} other._sign) or (self._field_values {op} other._field_values)
if isinstance(other, Predicate):
return self._raw {op} other._raw
return NotImplemented
"""

PREDICATE_CMP_TUPLE=r"""
def __{name}__(self, other):
if self.__class__ == other.__class__:
self._field_values {op} other._field_values
if isinstance(other, Predicate):
return self._raw {op} other._raw
return other {op} self._field_values
"""

def make_predicate_cmp(name, op, is_tuple=False):
if is_tuple:
return PREDICATE_CMP_NON_TUPLE.format(name=name, op=op)
return PREDICATE_CMP_TUPLE.format(name=name, op=op)




CHECK_SIGN_TEMPLATE = r"""
# Check if the sign is allowed
if self._sign != {sign}:
Expand Down
Loading

0 comments on commit 931d326

Please sign in to comment.