Skip to content

Commit

Permalink
Main change to the semantics of Predicate comparison
Browse files Browse the repository at this point in the history
Predicate comparison now uses the underlying clingo.Symbol object. Will allow querying with
ordering to work even when the types are incomparable because Symbol objects are always
comparable.

Still needs to update some of the querying behaviour
  • Loading branch information
daveraja committed May 5, 2024
1 parent 846b3ab commit fd30945
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 163 deletions.
198 changes: 89 additions & 109 deletions clorm/orm/core.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
# -----------------------------------------------------------------------------
# Implementation of the core part of the Clorm ORM. In particular this provides
# the base classes and metaclasses for the definition of fields, predicates,
# predicate paths, and the specification of query conditions. Note: query
# condition specification is provided here because the predicate path comparison
# operators are overloads to return these objects. However, the rest of the
# query API is specified with the FactBase and select querying mechanisms
# (see factbase.py).
# ------------------------------------------------------------------------------
# -------------------------------------------------------------------------------------------
# Implementation of the core part of the Clorm ORM. In particular this provides the base
# classes and metaclasses for the definition of fields, predicates, predicate paths, and the
# specification of query conditions. Note: query condition specification is provided here
# because the predicate path comparison operators are overloads to return these
# objects. However, the rest of the query API is specified with the FactBase and select
# querying mechanisms (see factbase.py).
# -------------------------------------------------------------------------------------------

# -------------------------------------------------------------------------------------------
# NOTE: 20242028 the semantics for the comparison operators has changed. Instead of using the
# python field representation we use the underlying clingo symbol object. The symbol object is
# well defined for any comparison between symbols, whereas tuples are only well defined if the
# types of the individual parameters are compatible. So this change leads to more natural
# behaviour for the queries. Note: users should avoid defining unintuitive fields (for example
# a swap field that changes the sign of an int) to avoid unintuitive Python behaviour.
# -------------------------------------------------------------------------------------------


from __future__ import annotations

Expand Down Expand Up @@ -1341,41 +1350,42 @@ class DateField(StringField):
``FactBase```. Defaults to ``False``.
"""

def __init__(self, default: Any = MISSING, index: Any = MISSING) -> None:
self._index = index if index is not MISSING else False

if default is MISSING:
self._default = (False, None)
return

self._default = (True, default)
cmplx = self.complex

# Check and convert the default to a valid value. Note: if the default
# is a callable then we can't do this check because it could break a
# counter type procedure.
if callable(default) or (cmplx and isinstance(default, cmplx)):
return
def _process_basic_value(v):
return v

try:
if cmplx:
def _process_cmplx_value(v):
if isinstance(v, cmplx):
return v
if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple):
return cmplx(*v)
raise TypeError(f"Value {v} ({type(v)}) cannot be converted to type {cmplx}")

def _instance_from_tuple(v):
if isinstance(v, tuple) or (isinstance(v, Predicate) and v.meta.is_tuple):
return cmplx(*v)
raise TypeError(f"Value {v} ({type(v)}) is not a tuple")
_process_value = _process_basic_value if cmplx is None else _process_cmplx_value

# If the default is not a factory function than make sure the value can be converted to
# clingo without error.
if not callable(default):
try:
self._default = (True, _process_value(default))
self.pytocl(self._default[1])
except (TypeError, ValueError):
raise TypeError(
'Invalid default value "{}" for {}'.format(default, type(self).__name__)
)
else:
def _process_default():
return _process_value(default())
self._default = (True, _process_default)

if cmplx.meta.is_tuple:
self._default = (True, _instance_from_tuple(default))
else:
raise ValueError("Bad default")
else:
self.pytocl(default)
except (TypeError, ValueError):
raise TypeError(
'Invalid default value "{}" for {}'.format(default, type(self).__name__)
)

@staticmethod
@abc.abstractmethod
Expand Down Expand Up @@ -1503,11 +1513,11 @@ def field(
raise TypeError(f"{basefield} can just be of Type '{BaseField}' or '{Sequence}'")


# ------------------------------------------------------------------------------
# RawField is a sub-class of BaseField for storing Symbol or NoSymbol
# objects. The behaviour of Raw with respect to using clingo.Symbol or
# noclingo.NoSymbol is modified by the symbol mode (get_symbol_mode())
# ------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------
# RawField is a sub-class of BaseField for storing Symbol objects. The behaviour of Raw with
# respect to using clingo.Symbol or noclingo.Symbol is modified by the symbol mode
# (get_symbol_mode())
# ------------------------------------------------------------------------------------------


class Raw(object):
Expand Down Expand Up @@ -2441,13 +2451,9 @@ def get_field_definition(defn: Any, module: str = "") -> BaseField:


def _create_complex_term(defn: Any, default_value: Any = MISSING, module: str = "") -> BaseField:
# NOTE: I was using a dict rather than OrderedDict which just happened to
# work. Apparently, in Python 3.6 this was an implmentation detail and
# Python 3.7 it is a language specification (see:
# https://stackoverflow.com/questions/1867861/how-to-keep-keys-values-in-same-order-as-declared/39537308#39537308).
# However, since Clorm is meant to be Python 3.5 compatible change this to
# use an OrderedDict.
# proto = { "arg{}".format(i+1) : get_field_definition(d) for i,d in enumerate(defn) }
# NOTE: relies on a dict preserving insertion order - this is true from Python 3.7+. Python
# 3.7 is already end-of-life so there is no longer a reason to use OrderedDict.
#proto = {f"arg{idx+1}": get_field_definition(dn) for idx, dn in enumerate(defn)}
proto: Dict[str, Any] = collections.OrderedDict(
[(f"arg{i+1}", get_field_definition(d, module)) for i, d in enumerate(defn)]
)
Expand Down Expand Up @@ -2705,13 +2711,16 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N

gdict = {
"Predicate": Predicate,
"Symbol": Symbol,
"Function": Function,
"MISSING": MISSING,
"AnySymbol": AnySymbol,
"Type": Type,
"Any": Any,
"Optional": Optional,
"Sequence": Sequence,
"_P": _P,
"PREDICATE_IS_TUPLE": pdefn.is_tuple,
}

for f in pdefn:
Expand Down Expand Up @@ -2786,21 +2795,26 @@ def _generate_dynamic_predicate_functions(class_name: str, namespace: Dict) -> N

template = PREDICATE_TEMPLATE.format(pdefn=pdefn)
predicate_functions = expand_template(template, **expansions)
# print(f"INIT:\n\n{predicate_functions}\n\n")
# print(f"INIT {class_name}:\n\n{predicate_functions}\n\n")

ldict: Dict[str, Any] = {}
exec(predicate_functions, gdict, ldict)

init_doc_args = f"{args_signature}*, sign=True, raw=None"
predicate_init = ldict["__init__"]
predicate_init.__name__ = "__init__"
predicate_init.__doc__ = f"{class_name}({init_doc_args})"
predicate_unify = ldict["_unify"]
predicate_unify.__name__ = "_unify"
predicate_unify.__doc__ = PREDICATE_UNIFY_DOCSTRING
def _set_fn(fname: str, docstring: str):
tmp = ldict[fname]
tmp.__name__ = fname
tmp.__doc = docstring
namespace[fname] = tmp

namespace["__init__"] = predicate_init
namespace["_unify"] = predicate_unify
# Assign the __init__, _unify, __hash__, and appropriate comparison functions
_set_fn("__init__", f"{class_name}({args_signature}*, sign=True, raw=None)")
_set_fn("_unify", PREDICATE_UNIFY_DOCSTRING)
_set_fn("__hash__", "Hash operator")
_set_fn("__eq__", "Equality operator")
_set_fn("__lt__", "Less than operator")
_set_fn("__le__", "Less than or equal operator")
_set_fn("__gt__", "Greater than operator")
_set_fn("__ge__", "Greater than operator")


# ------------------------------------------------------------------------------
Expand Down Expand Up @@ -3160,6 +3174,7 @@ def _cltopy(v):
# ------------------------------------------------------------------------------
# A Metaclass for the Predicate base class
# ------------------------------------------------------------------------------

@__dataclass_transform__(field_descriptors=(field,))
class _PredicateMeta(type):
if TYPE_CHECKING:
Expand Down Expand Up @@ -3244,8 +3259,12 @@ def __iter__(self) -> Iterator[PredicatePath]:
# underlying Symbol object.
# ------------------------------------------------------------------------------

# Mixin class to be able to use both MetaClasses
class _AbstractPredicateMeta(abc.ABCMeta, _PredicateMeta):
pass


class Predicate(object, metaclass=_PredicateMeta):
class Predicate(object, metaclass=_AbstractPredicateMeta):
"""Abstract base class to encapsulate an ASP predicate or complex term.
This is the heart of the ORM model for defining the mapping of a predicate
Expand Down Expand Up @@ -3324,7 +3343,7 @@ def _unify(
def symbol(self):
"""Returns the Symbol object corresponding to the fact.
The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.NoSymbol``.
The type of the object maybe either a ``clingo.Symbol`` or ``noclingo.Symbol``.
"""
return self._raw

Expand Down Expand Up @@ -3413,74 +3432,35 @@ def __neg__(self):
# --------------------------------------------------------------------------
# Overloaded operators
# --------------------------------------------------------------------------
@abc.abstractmethod
def __eq__(self, other):
"""Overloaded boolean operator."""
if isinstance(other, self.__class__):
return self._field_values == other._field_values and self._sign == other._sign
if self.meta.is_tuple:
return self._field_values == other
elif isinstance(other, Predicate):
return False
return NotImplemented
raise NotImplementedError("Predicate.__eq__() must be overriden")

@abc.abstractmethod
def __lt__(self, other):
"""Overloaded boolean operator."""
raise NotImplementedError("Predicate.__lt__() must be overriden")

# If it is the same predicate class then compare the sign and fields
if isinstance(other, self.__class__):

# Negative literals are less than positive literals
if self.sign != other.sign:
return self.sign < other.sign

return self._field_values < other._field_values

# If different predicates then compare the raw value
elif isinstance(other, Predicate):
return self.raw < other.raw

# Else an error
return NotImplemented
def __le__(self, other):
"""Overloaded boolean operator."""
raise NotImplementedError("Predicate.__le__() must be overriden")

@abc.abstractmethod
def __ge__(self, other):
"""Overloaded boolean operator."""
result = self.__lt__(other)
if result is NotImplemented:
return NotImplemented
return not result
raise NotImplementedError("Predicate.__ge__() must be overriden")

@abc.abstractmethod
def __gt__(self, other):
"""Overloaded boolean operator."""
raise NotImplementedError("Predicate.__gt__() must be overriden")

# If it is the same predicate class then compare the sign and fields
if isinstance(other, self.__class__):
# Positive literals are greater than negative literals
if self.sign != other.sign:
return self.sign > other.sign

return self._field_values > other._field_values

# If different predicates then compare the raw value
if not isinstance(other, Predicate):
return self.raw > other.raw

# Else an error
return NotImplemented

def __le__(self, other):
"""Overloaded boolean operator."""
result = self.__gt__(other)
if result is NotImplemented:
return NotImplemented
return not result

@abc.abstractmethod
def __hash__(self):
if self._hash is None:
if self.meta.is_tuple:
self._hash = hash(self._field_values)
else:
self._hash = hash((self.meta.name, self._field_values))
return self._hash
"""Overload the hash function."""
raise NotImplementedError("Predicate.__hash__() must be overriden")

def __str__(self):
"""Returns the Predicate as the string representation of an ASP fact."""
Expand Down
Loading

0 comments on commit fd30945

Please sign in to comment.