Merge pull request #16 from Overseas-Student-Living/auto-join
Auto join
mattbennett authored Feb 12, 2018
2 parents 07e38e1 + 6680bb8 commit 1e6408c
Showing 9 changed files with 677 additions and 62 deletions.
52 changes: 44 additions & 8 deletions README.rst
@@ -72,9 +72,36 @@ It is also possible to filter queries that contain multiple models, including jo
result = filtered_query.all()
You must specify the `model` key in each filter if the query is against more than one model.
Note that we can also apply filters to queries defined by fields or functions:
`apply_filters` will attempt to automatically join models to `query` if they're not already present and a model-specific filter is supplied. For example, the value of `filtered_query` in the following two code blocks is identical:
.. code-block:: python
query = session.query(Foo).join(Bar) # join pre-applied to query
filter_spec = [
    {'model': 'Foo', 'field': 'name', 'op': '==', 'value': 'name_1'},
    {'model': 'Bar', 'field': 'count', 'op': '>=', 'value': 5},
]
filtered_query = apply_filters(query, filter_spec)
.. code-block:: python
query = session.query(Foo) # join to Bar will be automatically applied
filter_spec = [
    {'field': 'name', 'op': '==', 'value': 'name_1'},
    {'model': 'Bar', 'field': 'count', 'op': '>=', 'value': 5},
]
filtered_query = apply_filters(query, filter_spec)
The automatic join is only possible if SQLAlchemy can implicitly determine the condition for the join, for example because of a foreign key relationship.
Automatic joins allow flexibility for clients to filter and sort by related objects without specifying all possible joins on the server beforehand.
Note that the first filter of the second block does not specify a model. It is implicitly applied to the `Foo` model because that is the only model in the original query passed to `apply_filters`.
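A minimal sketch of the model resolution described above, assuming plain model names in place of SQLAlchemy classes. `resolve_model_name` is a hypothetical helper, not part of the library, and the error raised for an unknown model is an assumption. It only mirrors the rule that a filter without a `model` key falls back to the single model of the original query, even after another model has been joined.

```python
def resolve_model_name(spec, query_model_names, default_model_name=None):
    """Hypothetical stand-in for the model lookup; works on names only."""
    if 'model' in spec:
        name = spec['model']
        if name not in query_model_names:
            # Assumed behaviour: a named model must be part of the query.
            raise ValueError(
                'The query does not contain model `{}`.'.format(name)
            )
        return name
    if len(query_model_names) == 1:
        return query_model_names[0]
    if default_model_name is not None:
        # The spec names no model, so use the original query's only model.
        return default_model_name
    raise ValueError('Ambiguous spec. Please specify a model.')


# The original query was against Foo only, so Foo is the default model.
original_models = ['Foo']
default = original_models[0] if len(original_models) == 1 else None

# After the automatic join the query contains both models.
models_after_join = ['Foo', 'Bar']

filter_spec = [
    {'field': 'name', 'op': '==', 'value': 'name_1'},
    {'model': 'Bar', 'field': 'count', 'op': '>=', 'value': 5},
]
resolved = [
    resolve_model_name(spec, models_after_join, default)
    for spec in filter_spec
]
print(resolved)
```

The default is taken from the query before the join is applied, which is why the first filter still resolves to `Foo` once `Bar` is present.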
It is also possible to apply filters to queries defined by fields or functions:
.. code-block:: python
@@ -118,17 +145,21 @@ The default SQLAlchemy join is lazy, meaning that columns from the joined table
`apply_loads` cannot be applied to columns that are loaded as `joined eager loads <http://docs.sqlalchemy.org/en/latest/orm/loading_relationships.html#joined-eager-loading>`_. This is because a joined eager load does not add the joined model to the original query, as explained `here <http://docs.sqlalchemy.org/en/latest/orm/loading_relationships.html#the-zen-of-joined-eager-loading>`_.
The following would produce an error:
The following would not prevent all columns from Bar being eagerly loaded:
.. code-block:: python
query = session.query(Foo).options(joinedload(Bar))
query = session.query(Foo).options(joinedload(Foo.bar))
load_spec = [
{'model': 'Foo', 'fields': ['name']},
{'model': 'Bar', 'fields': ['count']} # invalid
{'model': 'Bar', 'fields': ['count']}
]
query = apply_loads(query, load_spec) # error! query does not contain model Bar
query = apply_loads(query, load_spec)
.. sidebar:: Automatic Join
In fact, what happens here is that `Bar` is automatically joined to `query`, because it is determined that `Bar` is not part of the original query. The `load_spec` therefore has no effect because the automatic join
results in lazy evaluation.
If you wish to perform a joined load with restricted columns, you must specify the columns as part of the joined load, rather than with `apply_loads`:
@@ -159,6 +190,11 @@ Sort
result = sorted_query.all()
`apply_sort` will attempt to automatically join models to `query` if they're not already present and a model-specific sort is supplied. The behaviour is the same as in `apply_filters`.
This allows flexibility for clients to sort by fields on related objects without specifying all possible joins on the server beforehand.
Pagination
----------
@@ -193,7 +229,7 @@ following format:
# ...
]
The `model` key is optional if the query being filtered only applies to one model.
The `model` key is optional if the original query being filtered only applies to one model.
If there is only one filter, the containing list may be omitted:
@@ -263,7 +299,7 @@ applied sequentially:
Where ``field`` is the name of the field that will be sorted using the
provided ``direction``.
The `model` key is optional if the query being sorted only applies to one model.
The `model` key is optional if the original query being sorted only applies to one model.
Running tests
90 changes: 69 additions & 21 deletions sqlalchemy_filters/filters.py
@@ -7,7 +7,7 @@
from sqlalchemy import and_, or_, not_

from .exceptions import BadFilterFormat
from .models import Field, get_model_from_spec
from .models import Field, auto_join, get_model_from_spec, get_default_model


BooleanFunction = namedtuple(
@@ -59,36 +59,67 @@ def __init__(self, operator=None):

class Filter(object):

    def __init__(self, filter_spec, query):
    def __init__(self, filter_spec):
        self.filter_spec = filter_spec

        try:
            field_name = filter_spec['field']
            filter_spec['field']
        except KeyError:
            raise BadFilterFormat('`field` is a mandatory filter attribute.')
        except TypeError:
            raise BadFilterFormat(
                'Filter spec `{}` should be a dictionary.'.format(filter_spec)
            )

        model = get_model_from_spec(filter_spec, query)

        self.field = Field(model, field_name)
        self.operator = Operator(filter_spec.get('op'))
        self.value = filter_spec.get('value')
        self.value_present = True if 'value' in filter_spec else False

        if not self.value_present and self.operator.arity == 2:
        value_present = True if 'value' in filter_spec else False
        if not value_present and self.operator.arity == 2:
            raise BadFilterFormat('`value` must be provided.')

    def format_for_sqlalchemy(self):
        function = self.operator.function
        arity = self.operator.arity
        field = self.field.get_sqlalchemy_field()
    def get_named_models(self):
        if "model" in self.filter_spec:
            return {self.filter_spec['model']}
        return set()

    def format_for_sqlalchemy(self, query, default_model):
        filter_spec = self.filter_spec
        operator = self.operator
        value = self.value

        model = get_model_from_spec(filter_spec, query, default_model)

        function = operator.function
        arity = operator.arity

        field_name = self.filter_spec['field']
        field = Field(model, field_name)
        sqlalchemy_field = field.get_sqlalchemy_field()

        if arity == 1:
            return function(field)
            return function(sqlalchemy_field)

        if arity == 2:
            return function(field, self.value)
            return function(sqlalchemy_field, value)


class BooleanFilter:

    def __init__(self, function, *filters):
        self.function = function
        self.filters = filters

    def get_named_models(self):
        models = set()
        for filter in self.filters:
            models.update(filter.get_named_models())
        return models

    def format_for_sqlalchemy(self, query, default_model):
        return self.function(*[
            filter.format_for_sqlalchemy(query, default_model)
            for filter in self.filters
        ])


def _is_iterable_filter(filter_spec):
@@ -100,12 +131,12 @@ def _is_iterable_filter(filter_spec):
)


def _build_sqlalchemy_filters(filter_spec, query):
def build_filters(filter_spec):
    """ Recursively process `filter_spec` """

    if _is_iterable_filter(filter_spec):
        return list(chain.from_iterable(
            _build_sqlalchemy_filters(item, query) for item in filter_spec
            build_filters(item) for item in filter_spec
        ))

    if isinstance(filter_spec, dict):
@@ -134,12 +165,19 @@ def _build_sqlalchemy_filters(filter_spec, query):
)
)
return [
boolean_function.sqlalchemy_fn(
*_build_sqlalchemy_filters(fn_args, query)
BooleanFilter(
boolean_function.sqlalchemy_fn, *build_filters(fn_args)
)
]

return [Filter(filter_spec, query).format_for_sqlalchemy()]
return [Filter(filter_spec)]


def get_named_models(filters):
    models = set()
    for filter in filters:
        models.update(filter.get_named_models())
    return models


def apply_filters(query, filter_spec):
@@ -177,7 +215,17 @@ def apply_filters(query, filter_spec):
        The :class:`sqlalchemy.orm.Query` instance after all the filters
        have been applied.
    """
    sqlalchemy_filters = _build_sqlalchemy_filters(filter_spec, query)
    filters = build_filters(filter_spec)

    default_model = get_default_model(query)

    filter_models = get_named_models(filters)
    query = auto_join(query, *filter_models)

    sqlalchemy_filters = [
        filter.format_for_sqlalchemy(query, default_model)
        for filter in filters
    ]

    if sqlalchemy_filters:
        query = query.filter(*sqlalchemy_filters)
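The change to `apply_filters` splits the work into two phases: filter objects are built without touching the query, the models they name are collected (so the joins can be applied), and only then is each filter formatted. A minimal sketch of the first two phases follows, assuming simplified stand-ins: `LeafFilter`, `GroupFilter` and `build` are hypothetical names rather than the library's classes, the boolean keys `and`/`or`/`not` are assumed spec keys, and `Baz` is an illustrative model name.

```python
class LeafFilter:
    """Stand-in for `Filter`: stores the spec, resolves nothing yet."""

    def __init__(self, spec):
        self.spec = spec

    def get_named_models(self):
        return {self.spec['model']} if 'model' in self.spec else set()


class GroupFilter:
    """Stand-in for `BooleanFilter`: wraps the filters it combines."""

    def __init__(self, key, *filters):
        self.key = key
        self.filters = filters

    def get_named_models(self):
        # Recurse so that models named inside nested groups are found too.
        models = set()
        for child in self.filters:
            models.update(child.get_named_models())
        return models


BOOLEAN_KEYS = ('and', 'or', 'not')  # assumed spec keys


def build(spec):
    """Recursively turn a spec into filter objects, without any query."""
    if isinstance(spec, (list, tuple)):
        return [f for item in spec for f in build(item)]
    for key in BOOLEAN_KEYS:
        if key in spec:
            return [GroupFilter(key, *build(spec[key]))]
    return [LeafFilter(spec)]


def get_named_models(filters):
    models = set()
    for f in filters:
        models.update(f.get_named_models())
    return models


filter_spec = [
    {'field': 'name', 'op': '==', 'value': 'name_1'},
    {'or': [
        {'model': 'Bar', 'field': 'count', 'op': '>=', 'value': 5},
        {'model': 'Baz', 'field': 'count', 'op': '<', 'value': 2},
    ]},
]
filters = build(filter_spec)
named = get_named_models(filters)
print(sorted(named))
```

Deferring the formatting step is what makes the join possible: the set of named models is known before any field lookup is attempted against the query.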
44 changes: 35 additions & 9 deletions sqlalchemy_filters/loads.py
@@ -1,12 +1,14 @@
from sqlalchemy.orm import Load

from .exceptions import BadLoadFormat
from .models import Field, get_model_from_spec
from .models import Field, auto_join, get_model_from_spec, get_default_model


class LoadOnly(object):

    def __init__(self, load_spec, query):
    def __init__(self, load_spec):
        self.load_spec = load_spec

        try:
            field_names = load_spec['fields']
        except KeyError:
@@ -18,17 +20,34 @@ def __init__(self, load_spec, query):
                'Load spec `{}` should be a dictionary.'.format(load_spec)
            )

        self.model = get_model_from_spec(load_spec, query)
        self.fields = [
            Field(self.model, field_name) for field_name in field_names
        self.field_names = field_names

    def get_named_models(self):
        if "model" in self.load_spec:
            return {self.load_spec['model']}
        return set()

    def format_for_sqlalchemy(self, query, default_model):
        load_spec = self.load_spec
        field_names = self.field_names

        model = get_model_from_spec(load_spec, query, default_model)
        fields = [
            Field(model, field_name) for field_name in field_names
        ]

    def format_for_sqlalchemy(self):
        return Load(self.model).load_only(
            *[field.get_sqlalchemy_field() for field in self.fields]
        return Load(model).load_only(
            *[field.get_sqlalchemy_field() for field in fields]
        )


def get_named_models(loads):
    models = set()
    for load in loads:
        models.update(load.get_named_models())
    return models


def apply_loads(query, load_spec):
"""Apply load restrictions to a :class:`sqlalchemy.orm.Query` instance.
@@ -62,8 +81,15 @@ def apply_loads(query, load_spec):
    if isinstance(load_spec, dict):
        load_spec = [load_spec]

    loads = [LoadOnly(item) for item in load_spec]

    default_model = get_default_model(query)

    load_models = get_named_models(loads)
    query = auto_join(query, *load_models)

    sqlalchemy_loads = [
        LoadOnly(item, query).format_for_sqlalchemy() for item in load_spec
        load.format_for_sqlalchemy(query, default_model) for load in loads
    ]
    if sqlalchemy_loads:
        query = query.options(*sqlalchemy_loads)
43 changes: 42 additions & 1 deletion sqlalchemy_filters/models.py
@@ -1,3 +1,4 @@
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.inspection import inspect

from .exceptions import BadQuery, FieldNotFound, BadSpec
@@ -35,7 +36,7 @@ def get_query_models(query):
}


def get_model_from_spec(spec, query):
def get_model_from_spec(spec, query, default_model=None):
""" Determine the model to which a spec applies on a given query.
A spec that does not specify a model may be applied to a query that
@@ -74,9 +75,49 @@ def get_model_from_spec(spec, query):
    else:
        if len(models) == 1:
            model = list(models.values())[0]
        elif default_model is not None:
            return default_model
        else:
            raise BadSpec(
                "Ambiguous spec. Please specify a model."
            )

    return model


def get_model_class_by_name(registry, name):
    """ Return the model class matching `name` in the given `registry`.
    """
    for cls in registry.values():
        if getattr(cls, '__name__', None) == name:
            return cls


def get_default_model(query):
    """ Return the singular model from `query`, or `None` if `query` contains
    multiple models.
    """
    query_models = get_query_models(query).values()
    if len(query_models) == 1:
        default_model, = iter(query_models)
    else:
        default_model = None
    return default_model


def auto_join(query, *model_names):
    """ Automatically join models to `query` if they're not already present
    and the join can be done implicitly.
    """
    # every model has access to the registry, so we can use any from the query
    query_models = get_query_models(query).values()
    model_registry = list(query_models)[-1]._decl_class_registry

    for name in model_names:
        model = get_model_class_by_name(model_registry, name)
        if model not in get_query_models(query).values():
            try:
                query = query.join(model)
            except InvalidRequestError:
                pass  # can't be autojoined
    return query
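The control flow of `auto_join` can be illustrated without a database. The sketch below is an assumption-laden stand-in rather than the library's code: `FakeQuery` and `JoinError` are hypothetical, models are represented by their names, and `JoinError` plays the role of SQLAlchemy's `InvalidRequestError`. It shows the three cases the function handles: a model that gets joined, a model already in the query, and a model that cannot be joined implicitly.

```python
class JoinError(Exception):
    """Stand-in for sqlalchemy.exc.InvalidRequestError."""


class FakeQuery:
    """Hypothetical immutable query that tracks model names only."""

    def __init__(self, models, joinable):
        self.models = tuple(models)
        self.joinable = frozenset(joinable)

    def join(self, name):
        # Mimic a join that only works when a join condition can be inferred.
        if name not in self.joinable:
            raise JoinError('no implicit join condition for ' + name)
        return FakeQuery(self.models + (name,), self.joinable)


def auto_join(query, *model_names):
    for name in model_names:
        if name not in query.models:
            try:
                query = query.join(name)
            except JoinError:
                pass  # cannot be joined implicitly; query stays unchanged
    return query


query = FakeQuery(['Foo'], joinable={'Bar'})
joined = auto_join(query, 'Bar', 'Foo', 'Baz')
print(joined.models)
```

Because a failed join is swallowed, a model that cannot be joined is simply left out of the returned query, and the original query object is not modified.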