Skip to content

Commit

Permalink
Merge pull request #12 from Overseas-Student-Living/multiple-models
Browse files Browse the repository at this point in the history
Multiple models
  • Loading branch information
mattbennett authored Nov 15, 2017
2 parents c8786ea + 09e022d commit 54b1f9f
Show file tree
Hide file tree
Showing 10 changed files with 313 additions and 124 deletions.
21 changes: 16 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
sudo: false
language: python

python:
- '2.7'
services:
- mysql

install:
- pip install tox

env:
- TOX_ENV=py33
- TOX_ENV=py34
matrix:
include:
- stage: test
python: 3.3
env: TOX_ENV=py33
- stage: test
python: 3.4
env: TOX_ENV=py34
- stage: test
python: 3.5
env: TOX_ENV=py35
- stage: test
python: 3.6
env: TOX_ENV=py36

script:
- tox -e $TOX_ENV
52 changes: 40 additions & 12 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ SQLAlchemy-filters
Filtering
---------

Assuming that we have a SQLAlchemy `query` that only contains a single
model:
Assuming that we have a SQLAlchemy `query` object:

.. code-block:: python
Expand All @@ -33,7 +32,7 @@ model:
# ...
query = self.session.query(Foo)
query = session.query(Foo)
Then we can apply filters to that ``query`` object (multiple times):

Expand All @@ -51,12 +50,36 @@ Then we can apply filters to that ``query`` object (multiple times):
result = filtered_query.all()
It is also possible to filter queries that contain multiple models, including joins:

.. code-block:: python
class Bar(Base):
__tablename__ = 'bar'
foo_id = Column(Integer, ForeignKey('foo.id'))
.. code-block:: python
query = session.query(Foo).join(Bar)
filters = [
{'model': 'Foo', field': 'name', 'op': '==', 'value': 'name_1'},
{'model': 'Bar', field': 'count', 'op': '>=', 'value': 5},
]
filtered_query = apply_filters(query, filters)
result = filtered_query.all()
You must specify the `model` key in each filter if the query is against more than one model.
Note that we can also apply filters to queries defined by fields or functions:
.. code-block:: python
query_alt_1 = self.session.query(Foo.id, Foo.name)
query_alt_2 = self.session.query(func.count(Foo.id))
query_alt_1 = session.query(Foo.id, Foo.name)
query_alt_2 = session.query(func.count(Foo.id))
Sort
Expand All @@ -69,8 +92,8 @@ Sort
# `query` should be a SQLAlchemy query object
order_by = [
{'field': 'name', 'direction': 'asc'},
{'field': 'id', 'direction': 'desc'},
{'model': 'Foo', field': 'name', 'direction': 'asc'},
{'model': 'Bar', field': 'id', 'direction': 'desc'},
]
sorted_query = apply_sort(query, order_by)
Expand Down Expand Up @@ -106,12 +129,14 @@ following format:
.. code-block:: python
filters = [
{'field': 'field_name', 'op': '==', 'value': 'field_value'},
{'field': 'field_2_name', 'op': '!=', 'value': 'field_2_value'},
{'model': 'model_name', 'field': 'field_name', 'op': '==', 'value': 'field_value'},
{{'model': 'model_name', 'field': 'field_2_name', 'op': '!=', 'value': 'field_2_value'},
# ...
]
Optionally, if there is only one filter, the containing list may be omitted:
The `model` key is optional if the query being filtered only applies to one model.
If there is only one filter, the containing list may be omitted:
.. code-block:: python
Expand Down Expand Up @@ -171,14 +196,17 @@ applied sequentially:
.. code-block:: python
order_by = [
{'field': 'name', 'direction': 'asc'},
{'field': 'id', 'direction': 'desc'},
{'model': 'Foo', 'field': 'name', 'direction': 'asc'},
{'model': 'Bar', field': 'id', 'direction': 'desc'},
# ...
]
Where ``field`` is the name of the field that will be sorted using the
provided ``direction``.
The `model` key is optional if the query being sorted only applies to one model.
Running tests
-------------
Expand Down
10 changes: 3 additions & 7 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,20 @@
'sqlalchemy-utils==0.32.12',
],
'mysql': [
'mysql-connector-python==2.1.5',
'mysql-connector-python-rf==2.1.3',
]
},
dependency_links=[
'https://cdn.mysql.com/Downloads/Connector-Python'
'/mysql-connector-python-2.1.5.zip'
],
zip_safe=True,
license='Apache License, Version 2.0',
classifiers=[
"Programming Language :: Python",
"Operating System :: POSIX",
"Operating System :: MacOS :: MacOS X",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Topic :: Internet",
"Topic :: Software Development :: Libraries :: Python Modules",
"Intended Audience :: Developers",
Expand Down
18 changes: 17 additions & 1 deletion sqlalchemy_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,23 @@ def __init__(self, filter_, models):
'Filter `{}` should be a dictionary.'.format(filter_)
)

self.field = Field(models, field_name)
model_name = filter_.get('model')
if model_name is not None:
models = [v for (k, v) in models.items() if k == model_name]
if not models:
raise BadFilterFormat(
'The query does not contain model `{}`.'.format(model_name)
)
model = models[0]
else:
if len(models) == 1:
model = list(models.values())[0]
else:
raise BadFilterFormat(
"Ambiguous filter. Please specify a model."
)

self.field = Field(model, field_name)
self.operator = Operator(filter_.get('op'))
self.value = filter_.get('value')
self.value_present = True if 'value' in filter_ else False
Expand Down
19 changes: 6 additions & 13 deletions sqlalchemy_filters/models.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,14 @@
from sqlalchemy.inspection import inspect

from .exceptions import FieldNotFound, BadQuery
from .exceptions import FieldNotFound


class Field(object):

def __init__(self, models, field_name):
# TODO: remove this check once we start supporing multiple models
if len(models) > 1:
raise BadQuery('The query should contain only one model.')

self.model = self._get_model(models)
def __init__(self, model, field_name):
self.model = model
self.field_name = field_name

def _get_model(self, models):
# TODO: add model_name argument once we start supporing multiple models
return [v for (k, v) in models.items()][0] # first (and only) model

def get_sqlalchemy_field(self):
if self.field_name not in inspect(self.model).columns.keys():
raise FieldNotFound(
Expand All @@ -36,7 +28,8 @@ def get_query_models(query):
:returns:
A dictionary with all the models included in the query.
"""
models = [col_desc['entity'] for col_desc in query.column_descriptions]
models.extend(mapper.class_ for mapper in query._join_entities)
return {
col_desc['entity'].__name__: col_desc['entity']
for col_desc in query.column_descriptions
model.__name__: model for model in models
}
18 changes: 17 additions & 1 deletion sqlalchemy_filters/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,23 @@ def __init__(self, sort, models):
if direction not in [SORT_ASCENDING, SORT_DESCENDING]:
raise BadSortFormat('Direction `{}` not valid.'.format(direction))

self.field = Field(models, field_name)
model_name = sort.get('model')
if model_name is not None:
models = [v for (k, v) in models.items() if k == model_name]
if not models:
raise BadSortFormat(
'The query does not contain model `{}`.'.format(model_name)
)
model = models[0]
else:
if len(models) == 1:
model = list(models.values())[0]
else:
raise BadSortFormat(
"Ambiguous sort. Please specify a model.".format()
)

self.field = Field(model, field_name)
self.direction = direction

def format_for_sqlalchemy(self):
Expand Down
Loading

0 comments on commit 54b1f9f

Please sign in to comment.