Skip to content

Commit

Permalink
Big refactoring to lay the basis for interval querying (swisscom#118)
Browse files Browse the repository at this point in the history
All existing tests pass.
  • Loading branch information
brki committed Aug 25, 2016
1 parent c911bd3 commit d034838
Showing 1 changed file with 171 additions and 52 deletions.
223 changes: 171 additions & 52 deletions versions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,65 @@ def get_utc_now():
return datetime.datetime.utcnow().replace(tzinfo=utc)


QueryTime = namedtuple('QueryTime', 'time active')
class SimpleEqualityMixin(object):
    """
    Mixin that gives a class attribute-based equality.

    Two objects are considered equal when they are of exactly the same type
    and their instance dictionaries compare equal.
    """

    def __eq__(self, other):
        """Return True if ``other`` has the same type and an equal ``__dict__``."""
        if type(self) != type(other):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        """Return the negation of ``__eq__`` for ``other``."""
        is_equal = self.__eq__(other)
        return not is_equal


class QueryTimeInterval(SimpleEqualityMixin):
    """
    A time interval, delimited by two datetimes, that a query can be restricted to.

    Equality between instances comes from SimpleEqualityMixin (same type and
    equal attribute values).
    """
    start_time = None
    end_time = None
    # If true, only include last version, otherwise include all versions in interval
    unique = False

    def __init__(self, start_time, end_time, unique=False):
        """
        :param datetime.datetime start_time: beginning of the interval
        :param datetime.datetime end_time: end of the interval; must not be earlier than start_time
        :param bool unique: see the comment on the ``unique`` class attribute
        :raises ValueError: if either bound is not a datetime object, or if
            start_time is later than end_time
        """
        # In this module ``datetime`` is the module object (see get_utc_now), so the
        # class has to be referenced as ``datetime.datetime`` for isinstance().
        if not (isinstance(start_time, datetime.datetime) and isinstance(end_time, datetime.datetime)):
            raise ValueError("start_time and end_time parameters need to be datetime objects")
        # Compare the two bounds with each other (not with the datetime module).
        if not start_time <= end_time:
            raise ValueError("start_time must not be later than end_time")
        self.start_time = start_time
        self.end_time = end_time
        self.unique = unique

    def _clone(self):
        """Return a new QueryTimeInterval carrying the same bounds and ``unique`` flag."""
        return QueryTimeInterval(start_time=self.start_time, end_time=self.end_time, unique=self.unique)


class QueryTime(SimpleEqualityMixin):
    """
    Describes how a query should be restricted in time: either to a point in
    time or to an interval.

    Equality between instances comes from SimpleEqualityMixin (same type and
    equal attribute values).
    """
    TYPE_POINT_IN_TIME = 'point_in_time'
    TYPE_INTERVAL = 'interval'

    # TYPE_POINT_IN_TIME or TYPE_INTERVAL
    type = TYPE_POINT_IN_TIME

    # Whether or not the query should restrict to a point in time or interval.
    # NOTE(review): this class-level default (True) differs from the __init__
    # default (False); instances always carry the value set by __init__.
    active = True

    # datetime if type is TYPE_POINT_IN_TIME. If type is TYPE_POINT_IN_TIME and
    # this value is None, then query current objects only.
    time = None

    # A QueryTimeInterval if type is TYPE_INTERVAL, otherwise None
    interval = None

    def __init__(self, type=TYPE_POINT_IN_TIME, active=False, time=None, interval=None):
        """
        :param type: TYPE_POINT_IN_TIME or TYPE_INTERVAL
        :param bool active: whether the time restriction should be applied
        :param time: the point in time, or None
        :param interval: a QueryTimeInterval; required when type is TYPE_INTERVAL
        :raises ValueError: if type is TYPE_INTERVAL and no interval is given
        """
        # Validate before assigning anything, so a rejected call leaves no
        # half-initialised state behind.
        if type == self.TYPE_INTERVAL and not interval:
            raise ValueError('An "interval" parameter is necessary if the "type" is TYPE_INTERVAL')
        self.type = type
        self.active = active
        self.time = time
        self.interval = interval

    def _clone(self):
        """
        Return a new QueryTime with the same type, active flag and time; a set
        interval is cloned as well, so the copy does not share it.
        """
        interval = self.interval
        if interval:
            interval = interval._clone()
        return QueryTime(type=self.type, active=self.active, time=self.time, interval=interval)


class ForeignKeyRequiresValueError(ValueError):
Expand All @@ -78,7 +136,7 @@ def get_queryset(self):
"""
qs = VersionedQuerySet(self.model, using=self._db)
if hasattr(self, 'instance') and hasattr(self.instance, '_querytime'):
qs.querytime = self.instance._querytime
qs.querytime = self.instance._querytime._clone()
return qs

def as_of(self, time=None):
Expand All @@ -89,6 +147,15 @@ def as_of(self, time=None):
"""
return self.get_queryset().as_of(time)

def with_querytime(self, querytime):
    """
    Restrict this manager's queryset with the given QueryTime.

    :param QueryTime querytime: object describing how query should be time restricted.
    :return: whatever ``with_querytime(querytime)`` returns on the queryset
        obtained from ``get_queryset()``.
    """
    queryset = self.get_queryset()
    return queryset.with_querytime(querytime)

def next_version(self, object, relations_as_of='end'):
"""
Return the next version of the given object.
Expand Down Expand Up @@ -225,7 +292,7 @@ def adjust_version_as_of(version, relations_as_of):
)
version.as_of = as_of
elif relations_as_of is None:
version._querytime = QueryTime(time=None, active=False)
version._querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=False)
else:
raise TypeError("as_of parameter must be 'start', 'end', None, or datetime object")

Expand Down Expand Up @@ -305,7 +372,6 @@ def as_sql(self, qn, connection):
except AttributeError:
# Django 1.6 handles compilers as instancemethods
_query = qn.__self__.query
query_time = _query.querytime.time
apply_query_time = _query.querytime.active
alias_map = _query.alias_map
# In Django 1.6 & 1.7, use the join_map to know, what *table* gets joined to which
Expand All @@ -317,7 +383,7 @@ def as_sql(self, qn, connection):
self._set_child_joined_alias(child, alias_map)
if apply_query_time:
# Add query parameters that have not been added till now
child.set_as_of(query_time)
child.set_querytime(_query.querytime)
else:
# Remove the restriction if it's not required
child.sqls = []
Expand Down Expand Up @@ -382,13 +448,12 @@ def __init__(self, historic_sql, current_sql, alias, remote_alias):
self.current_sql = current_sql
self.alias = alias
self.related_alias = remote_alias
self._as_of_time_set = False
self.as_of_time = None
self.querytime = None
self._joined_alias = None

def set_as_of(self, as_of_time):
self.as_of_time = as_of_time
self._as_of_time_set = True
def set_querytime(self, querytime):
    """
    Store a private copy of ``querytime`` on this object and flag it as set.

    :param querytime: object providing a ``_clone()`` method; its clone is stored
        so that later changes to the caller's object are not shared.
    """
    own_copy = querytime._clone()
    self.querytime = own_copy
    self._querytime_set = True

def set_joined_alias(self, joined_alias):
"""
Expand All @@ -404,23 +469,25 @@ def as_sql(self, qn=None, connection=None):
params = []

# Fail fast for unacceptable cases
if self._as_of_time_set and not self._joined_alias:
raise ValueError("joined_alias is not set, but as_of is; this is a conflict!")
if self.querytime and not self._joined_alias:
raise ValueError("joined_alias is not set, but querytime is; this is a conflict!")

# Set the SQL string in dependency of whether as_of_time was set or not
if self._as_of_time_set:
if self.as_of_time:
sql = self.historic_sql
params = [self.as_of_time] * 2
# 2 is the number of occurences of the timestamp in an as_of-filter expression
else:
# If as_of_time was set to None, we're dealing with a query for "current" values
sql = self.current_sql
if self.querytime:
# TODO (interval-querying): figure out what to do here for TYPE_INTERVAL QueryTime
if self.querytime.type == QueryTime.TYPE_POINT_IN_TIME:
if self.querytime.time:
sql = self.historic_sql
params = [self.querytime.time] * 2
# 2 is the number of occurrences of the timestamp in an as_of-filter expression
else:
# If as_of_time was set to None, we're dealing with a query for "current" values
sql = self.current_sql
else:
# No as_of_time has been set; Perhaps, as_of was not part of the query -> That's OK
# No querytime has been set; Perhaps, as_of was not part of the query -> That's OK
pass

# By here, the sql string is defined if an as_of_time was provided
# By here, the sql string is defined if a querytime was provided
if self._joined_alias:
sql = sql.format(alias=self._joined_alias)

Expand All @@ -444,12 +511,12 @@ class VersionedQuery(Query):
def __init__(self, *args, **kwargs):
kwargs['where'] = VersionedWhereNode
super(VersionedQuery, self).__init__(*args, **kwargs)
self.querytime = QueryTime(time=None, active=False)
self.querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=False)

def clone(self, *args, **kwargs):
_clone = super(VersionedQuery, self).clone(*args, **kwargs)
try:
_clone.querytime = self.querytime
_clone.querytime = self.querytime._clone()
except AttributeError:
# If the caller is using clone to create a different type of Query, that's OK.
# An example of this is when creating or updating an object, this method is called
Expand All @@ -464,19 +531,42 @@ def get_compiler(self, *args, **kwargs):
object to work (they are attached to a queryset; filter() returns a new queryset).
"""
if self.querytime.active and (not hasattr(self, '_querytime_filter_added') or not self._querytime_filter_added):
time = self.querytime.time
if time is None:
self.add_q(Q(version_end_date__isnull=True))
if self.querytime.type == QueryTime.TYPE_POINT_IN_TIME:
self.add_point_in_time_filter()
elif self.querytime.type == QueryTime.TYPE_INTERVAL:
self.add_interval_filter()
else:
self.add_q(
(Q(version_end_date__gt=time) | Q(version_end_date__isnull=True))
& Q(version_start_date__lte=time)
)
raise RuntimeError("Unrecognized QueryTime.type")

# Ensure applying these filters happens only a single time (even if it doesn't falsify the query, it's
# just not very comfortable to read)
self._querytime_filter_added = True
return super(VersionedQuery, self).get_compiler(*args, **kwargs)

def add_point_in_time_filter(self):
    """
    Add the version-date restriction for a point-in-time query.

    If ``self.querytime.time`` is None, only versions without an end date are
    kept. Otherwise a version is kept when it started at or before that time
    and either ended after it or has no end date.
    """
    as_of = self.querytime.time
    if as_of is not None:
        not_ended_yet = Q(version_end_date__gt=as_of) | Q(version_end_date__isnull=True)
        already_started = Q(version_start_date__lte=as_of)
        self.add_q(not_ended_yet & already_started)
        return
    self.add_q(Q(version_end_date__isnull=True))

def add_interval_filter(self):
    """
    Add the version-date restriction for an interval query.

    A version without an end date matches when it started at or before the
    interval's end. A version with an end date matches when it starts before
    and ends after the interval, or when its start date or its end date lies
    within the interval (bounds included).
    """
    lower = self.querytime.interval.start_time
    upper = self.querytime.interval.end_time

    # Versions that are still open (no end date)
    open_version_matches = Q(version_end_date__isnull=True, version_start_date__lte=upper)

    # Versions that have been terminated
    spans_interval = Q(version_start_date__lt=lower, version_end_date__gt=upper)
    starts_inside = Q(version_start_date__gte=lower, version_start_date__lte=upper)
    ends_inside = Q(version_end_date__gte=lower, version_end_date__lte=upper)
    touches_interval = starts_inside | ends_inside
    closed_version_matches = Q(version_end_date__isnull=False) & (spans_interval | touches_interval)

    self.add_q(open_version_matches | closed_version_matches)

def build_filter(self, filter_expr, **kwargs):
"""
When a query is filtered with an expression like .filter(team=some_team_object),
Expand Down Expand Up @@ -533,7 +623,7 @@ def __init__(self, model=None, query=None, *args, **kwargs):
if not query:
query = VersionedQuery(model)
super(VersionedQuerySet, self).__init__(model=model, query=query, *args, **kwargs)
self.querytime = QueryTime(time=None, active=False)
self.querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=False)

@property
def querytime(self):
Expand Down Expand Up @@ -600,7 +690,7 @@ def _clone(self, *args, **kwargs):
kwargs['klass'] = klass

clone = super(VersionedQuerySet, self)._clone(**kwargs)
clone.querytime = self.querytime
clone.querytime = self.querytime._clone()
return clone

def _set_item_querytime(self, item, type_check=True):
Expand All @@ -611,9 +701,9 @@ def _set_item_querytime(self, item, type_check=True):
:return: Returns the item itself with the time set
"""
if isinstance(item, Versionable):
item._querytime = self.querytime
item._querytime = self.querytime._clone()
elif isinstance(item, VersionedQuerySet):
item.querytime = self.querytime
item.querytime = self.querytime._clone()
elif isinstance(self, ValuesQuerySet):
# When we are dealing with a ValueQuerySet there is no point in
# setting the query_time as we are returning an array of values
Expand All @@ -627,11 +717,21 @@ def _set_item_querytime(self, item, type_check=True):
def as_of(self, qtime=None):
"""
Sets the time for which we want to retrieve an object.
The qtime parameter can have one of these forms:
- the UTC date and time (a datetime object)
- None (this will limit the QuerySet results to the current objects)
- a tuple (start_utc_datetime, end_utc_datetime)
:param qtime: The UTC date and time; if None then use the current state (where version_end_date = NULL)
:return: A VersionedQuerySet
"""
clone = self._clone()
clone.querytime = QueryTime(time=qtime, active=True)
clone.querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=True, time=qtime)
return clone

def with_querytime(self, querytime):
    """
    Return a clone of this queryset that carries a copy of the given QueryTime.

    :param querytime: object providing a ``_clone()`` method; the clone, not the
        object itself, is attached to the returned queryset.
    :return: the clone produced by ``self._clone()``
    """
    restricted = self._clone()
    restricted.querytime = querytime._clone()
    return restricted

def delete(self):
Expand Down Expand Up @@ -700,6 +800,7 @@ def get_extra_restriction(self, where_class, alias, remote_alias):
:return: SQL conditional statement
:rtype: WhereNode
"""
# TODO (interval-querying): look into how to handle this for the case of a TYPE_INTERVAL querytime ...
historic_sql = '''{alias}.version_start_date <= %s
AND ({alias}.version_end_date > %s OR {alias}.version_end_date is NULL )'''
current_sql = '''{alias}.version_end_date is NULL'''
Expand Down Expand Up @@ -849,15 +950,16 @@ def __get__(self, instance, instance_type=None):
+ str(type(current_elt))
+ ", which is not a subclass of Versionable")

manager = current_elt.__class__.objects
if hasattr(instance, '_querytime'):
# If current_elt matches the instance's querytime, there's no need to make a database query.
if Versionable.matches_querytime(current_elt, instance._querytime):
current_elt._querytime = instance._querytime
current_elt._querytime = instance._querytime._clone()
return current_elt

return current_elt.__class__.objects.as_of(instance._querytime.time).get(identity=current_elt.identity)
return manager.with_querytime(instance._querytime).get(identity=current_elt.identity)
else:
return current_elt.__class__.objects.current.get(identity=current_elt.identity)
return manager.current.get(identity=current_elt.identity)


class VersionedForeignRelatedObjectsDescriptor(ForeignRelatedObjectsDescriptor):
Expand Down Expand Up @@ -885,9 +987,10 @@ def get_queryset(self):
queryset = super(VersionedRelatedManager, self).get_queryset()
# Do not set the query time if it is already correctly set. queryset.as_of() returns a clone
# of the queryset, and this will destroy the prefetched objects cache if it exists.
if isinstance(queryset,
VersionedQuerySet) and self.instance._querytime.active and queryset.querytime != self.instance._querytime:
queryset = queryset.as_of(self.instance._querytime.time)
if isinstance(queryset, VersionedQuerySet):
querytime = self.instance._querytime
if querytime.active and querytime != queryset.querytime:
queryset = queryset.with_querytime(querytime)
return queryset

def add(self, *objs):
Expand Down Expand Up @@ -970,11 +1073,11 @@ def get_queryset(self):
available).
Long story short, apply the temporal validity filter also to the intermediary model.
"""

queryset = super(VersionedManyRelatedManager, self).get_queryset()
if hasattr(queryset, 'querytime'):
if self.instance._querytime.active and self.instance._querytime != queryset.querytime:
queryset = queryset.as_of(self.instance._querytime.time)
querytime = self.instance._querytime
if querytime.active and querytime != queryset.querytime:
queryset = queryset.with_querytime(querytime)
return queryset

def _remove_items(self, source_field_name, target_field_name, *objs):
Expand Down Expand Up @@ -1241,7 +1344,7 @@ class Meta:
def __init__(self, *args, **kwargs):
super(Versionable, self).__init__(*args, **kwargs)
# _querytime is for library-internal use.
self._querytime = QueryTime(time=None, active=False)
self._querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=False)

def delete(self, using=None):
using = using or router.db_for_write(self.__class__, instance=self)
Expand Down Expand Up @@ -1304,7 +1407,7 @@ def as_of(self):

@as_of.setter
def as_of(self, time):
self._querytime = QueryTime(time=time, active=True)
self._querytime = QueryTime(type=QueryTime.TYPE_POINT_IN_TIME, active=True, time=time)

@staticmethod
def uuid():
Expand Down Expand Up @@ -1527,11 +1630,27 @@ def matches_querytime(instance, querytime):
if not querytime.active:
return True

if not querytime.time:
return instance.version_end_date is None

return (instance.version_start_date <= querytime.time
and (instance.version_end_date is None or instance.version_end_date > querytime.time))
if querytime.type == QueryTime.TYPE_POINT_IN_TIME:
if not querytime.time:
return instance.version_end_date is None

return (instance.version_start_date <= querytime.time
and (instance.version_end_date is None or instance.version_end_date > querytime.time))
elif querytime.type == QueryTime.TYPE_POINT_IN_TIME:
start = querytime.interval.start_time
end = querytime.interval.end_time
return (
# Instance has start or end in interval:
start >= instance.version_start_date < end
or (instance.version_end_date is not None and start >= instance.version_end_date < end)
) or (
# Instance started before and ended after the interval (or never ended):
start > instance.version_start_date
and (instance.version_end_date is None or end <= instance.version_end_date)
)
else:
raise RuntimeError(
"Unexpected value for querytime.type; it should be either TYPE_POINT_IN_TIME or TYPE_INTERVAL")


class VersionedManyToManyModel(object):
Expand Down

0 comments on commit d034838

Please sign in to comment.