Skip to content

Commit

Permalink
Merge pull request #7 from Overseas-Student-Living/sorting
Browse files Browse the repository at this point in the history
Add apply sort functionality
  • Loading branch information
juliotrigo authored Jan 6, 2017
2 parents 21c6cc9 + eb21f5a commit 58b74c9
Show file tree
Hide file tree
Showing 15 changed files with 401 additions and 86 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ install:
- pip install tox

env:
- TOX_ENV=py33
- TOX_ENV=py34

script:
Expand Down
26 changes: 26 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
Release Notes
=============

Here you can see the full list of changes between sqlalchemy-filters
versions, where semantic versioning is used: *major.minor.patch*.

Backwards-compatible changes increment the minor version number only.

Version 0.2.0
-------------

Released 2017-01-06

* Adds apply query pagination
* Adds apply query sort
* Adds Travis CI
* Starts using Tox
* Refactors Makefile and conftest

Version 0.1.0
-------------

Released 2016-09-08

* Initial version
* Adds apply query filters
35 changes: 35 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ Then we can apply filters to that ``query`` object (multiple times):
result = filtered_query.all()
Sort
----

.. code-block:: python
from sqlalchemy_filters import apply_sort
# `query` should be a SQLAlchemy query object
order_by = [
{'field': 'name', 'direction': 'asc'},
{'field': 'id', 'direction': 'desc'},
]
sorted_query = apply_sort(query, order_by)
result = sorted_query.all()
Pagination
----------

Expand Down Expand Up @@ -103,6 +121,23 @@ This is the list of operators that can be used:
- ``in``
- ``not_in``

Sort format
-----------

Sort elements must be provided as dictionaries in a list and will be
applied sequentially:

.. code-block:: python
order_by = [
{'field': 'name', 'direction': 'asc'},
{'field': 'id', 'direction': 'desc'},
# ...
]
Where ``field`` is the name of the field that will be sorted using the
provided ``direction``.

Running tests
-------------

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

setup(
name='sqlalchemy-filters',
version='0.1.0',
version='0.2.0',
description='A library to filter SQLAlchemy queries.',
long_description=readme,
author='Student.com',
Expand Down
3 changes: 2 additions & 1 deletion sqlalchemy_filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-

from .filters import apply_filters, get_query_models # noqa: F401
from .filters import apply_filters # noqa: F401
from .models import get_query_models # noqa: F401
from .pagination import apply_pagination # noqa: F401
from .sorting import apply_sort # noqa: F401
8 changes: 8 additions & 0 deletions sqlalchemy_filters/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@ class BadFilterFormat(Exception):
pass


class BadSortFormat(Exception):
pass


class FieldNotFound(Exception):
pass


class BadQuery(Exception):
pass

Expand Down
42 changes: 1 addition & 41 deletions sqlalchemy_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@

from inspect import signature

from sqlalchemy.inspection import inspect

from .exceptions import BadFilterFormat, BadQuery
from .models import Field, get_query_models


class Operator(object):
Expand Down Expand Up @@ -38,30 +37,6 @@ def __init__(self, operator):
self.arity = len(signature(self.function).parameters)


class Field(object):

def __init__(self, models, field_name):
# TODO: remove this check once we start supporing multiple models
if len(models) > 1:
raise BadQuery('The query should contain only one model.')

self.model = self._get_model(models)
self.field_name = field_name

def _get_model(self, models):
# TODO: add model_name argument once we start supporing multiple models
return [v for (k, v) in models.items()][0] # first (and only) model

def get_sqlalchemy_field(self):
if self.field_name not in inspect(self.model).columns.keys():
raise BadFilterFormat(
'Model {} has no column `{}`.'.format(
self.model, self.field_name
)
)
return getattr(self.model, self.field_name)


class Filter(object):

def __init__(self, filter_, models):
Expand Down Expand Up @@ -123,18 +98,3 @@ def apply_filters(query, filters):
query = query.filter(*sqlalchemy_filters)

return query


def get_query_models(query):
"""Get models from query.
:param query:
A :class:`sqlalchemy.orm.Query` instance.
:returns:
A dictionary with all the models included in the query.
"""
return {
entity['type'].__name__: entity['type']
for entity in query.column_descriptions
}
42 changes: 42 additions & 0 deletions sqlalchemy_filters/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from sqlalchemy.inspection import inspect

from .exceptions import FieldNotFound, BadQuery


class Field(object):

def __init__(self, models, field_name):
# TODO: remove this check once we start supporing multiple models
if len(models) > 1:
raise BadQuery('The query should contain only one model.')

self.model = self._get_model(models)
self.field_name = field_name

def _get_model(self, models):
# TODO: add model_name argument once we start supporing multiple models
return [v for (k, v) in models.items()][0] # first (and only) model

def get_sqlalchemy_field(self):
if self.field_name not in inspect(self.model).columns.keys():
raise FieldNotFound(
'Model {} has no column `{}`.'.format(
self.model, self.field_name
)
)
return getattr(self.model, self.field_name)


def get_query_models(query):
"""Get models from query.
:param query:
A :class:`sqlalchemy.orm.Query` instance.
:returns:
A dictionary with all the models included in the query.
"""
return {
entity['type'].__name__: entity['type']
for entity in query.column_descriptions
}
69 changes: 66 additions & 3 deletions sqlalchemy_filters/sorting.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,69 @@
# -*- coding: utf-8 -*-

from .exceptions import BadQuery, BadSortFormat
from .models import Field, get_query_models

def apply_sort(query, sort): # pragma: no cover
# TODO
raise NotImplemented()

SORT_ASCENDING = 'asc'
SORT_DESCENDING = 'desc'


class Sort(object):

def __init__(self, sort, models):
try:
field_name = sort['field']
direction = sort['direction']
except KeyError:
raise BadSortFormat(
'`field` and `direction` are mandatory attributes.'
)
except TypeError:
raise BadSortFormat(
'Sort `{}` should be a dictionary.'.format(sort)
)

if direction not in [SORT_ASCENDING, SORT_DESCENDING]:
raise BadSortFormat('Direction `{}` not valid.'.format(direction))

self.field = Field(models, field_name)
self.direction = direction

def format_for_sqlalchemy(self):
field = self.field.get_sqlalchemy_field()

if self.direction == SORT_ASCENDING:
return field.asc()
elif self.direction == SORT_DESCENDING:
return field.desc()


def apply_sort(query, order_by):
"""Apply sorting to a :class:`sqlalchemy.orm.Query` instance.
:param order_by:
A list of dictionaries, where each one of them includes
the necesary information to order the elements of the query.
Example::
order_by = [
{'field': 'name', 'direction': 'asc'},
{'field': 'id', 'direction': 'desc'},
]
:returns:
The :class:`sqlalchemy.orm.Query` instance after the provided
sorting has been applied.
"""
models = get_query_models(query)
if not models:
raise BadQuery('The query does not contain any models.')

sqlalchemy_order_by = [
Sort(sort, models).format_for_sqlalchemy() for sort in order_by
]
if sqlalchemy_order_by:
query = query.order_by(*sqlalchemy_order_by)

return query
5 changes: 5 additions & 0 deletions test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# -*- coding: utf-8 -*-


def error_value(exception):
return exception.value.args[0]
41 changes: 6 additions & 35 deletions test/interface/test_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,42 +3,13 @@
import datetime

import pytest
from sqlalchemy_filters import apply_filters, get_query_models
from sqlalchemy_filters.exceptions import BadFilterFormat, BadQuery
from sqlalchemy_filters import apply_filters
from sqlalchemy_filters.exceptions import (
BadFilterFormat, FieldNotFound, BadQuery
)
from test.models import Bar, Qux


class TestGetQueryModels(object):

def test_query_with_no_models(self, session):
query = session.query()

entities = get_query_models(query)

assert {} == entities

def test_query_with_one_model(self, session):
query = session.query(Bar)

entities = get_query_models(query)

assert {'Bar': Bar} == entities

def test_query_with_multiple_models(self, session):
query = session.query(Bar, Qux)

entities = get_query_models(query)

assert {'Bar': Bar, 'Qux': Qux} == entities

def test_query_with_duplicated_models(self, session):
query = session.query(Bar, Qux, Bar)

entities = get_query_models(query)

assert {'Bar': Bar, 'Qux': Qux} == entities


class TestProvidedModels(object):

def test_query_with_no_models(self, session):
Expand Down Expand Up @@ -131,7 +102,7 @@ def test_invalid_field(self, session):
query = session.query(Bar)
filters = [{'field': 'invalid_field', 'op': '==', 'value': 'name_1'}]

with pytest.raises(BadFilterFormat) as err:
with pytest.raises(FieldNotFound) as err:
apply_filters(query, filters)

expected_error = (
Expand All @@ -147,7 +118,7 @@ def test_invalid_field_but_valid_model_attribute(self, session, attr_name):
query = session.query(Bar)
filters = [{'field': attr_name, 'op': '==', 'value': 'name_1'}]

with pytest.raises(BadFilterFormat) as err:
with pytest.raises(FieldNotFound) as err:
apply_filters(query, filters)

expected_error = (
Expand Down
33 changes: 33 additions & 0 deletions test/interface/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sqlalchemy_filters import get_query_models
from test.models import Bar, Qux


class TestGetQueryModels(object):

def test_query_with_no_models(self, session):
query = session.query()

entities = get_query_models(query)

assert {} == entities

def test_query_with_one_model(self, session):
query = session.query(Bar)

entities = get_query_models(query)

assert {'Bar': Bar} == entities

def test_query_with_multiple_models(self, session):
query = session.query(Bar, Qux)

entities = get_query_models(query)

assert {'Bar': Bar, 'Qux': Qux} == entities

def test_query_with_duplicated_models(self, session):
query = session.query(Bar, Qux, Bar)

entities = get_query_models(query)

assert {'Bar': Bar, 'Qux': Qux} == entities
Loading

0 comments on commit 58b74c9

Please sign in to comment.