diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8ca75cc..27b2e88 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,6 +4,14 @@ Release Notes Here you can see the full list of changes between sqlalchemy-filters versions, where semantic versioning is used: *major.minor.patch*. +Unreleased +---------- + +* Add support for hybrid attributes (properties and methods): filtering + and sorting (#45) as a continuation of the work started here (#32) + by @vkylamba + - Addresses (#22) + 0.11.0 ------ diff --git a/README.rst b/README.rst index cddaa78..984a057 100644 --- a/README.rst +++ b/README.rst @@ -36,6 +36,14 @@ Assuming that we have a SQLAlchemy_ ``query`` object: name = Column(String(50), nullable=False) count = Column(Integer, nullable=True) + @hybrid_property + def count_square(self): + return self.count * self.count + + @hybrid_method + def three_times_count(self): + return self.count * 3 + Base = declarative_base(cls=Base) @@ -137,6 +145,21 @@ It is also possible to apply filters to queries defined by fields, functions or query_alt_2 = session.query(func.count(Foo.id)) query_alt_3 = session.query().select_from(Foo).add_column(Foo.id) +Hybrid attributes +^^^^^^^^^^^^^^^^^ + +You can filter by a `hybrid attribute`_: a `hybrid property`_ or a `hybrid method`_. + +.. code-block:: python + + query = session.query(Foo) + + filter_spec = [{'field': 'count_square', 'op': '>=', 'value': 25}] + filter_spec = [{'field': 'three_times_count', 'op': '>=', 'value': 15}] + + filtered_query = apply_filters(query, filter_spec) + result = filtered_query.all() + Restricted Loads ---------------- @@ -241,6 +264,11 @@ The behaviour is the same as in ``apply_filters``. This allows flexibility for clients to sort by fields on related objects without specifying all possible joins on the server beforehand. +Hybrid attributes +^^^^^^^^^^^^^^^^^ + +You can sort by a `hybrid attribute`_: a `hybrid property`_ or a `hybrid method`_. + Pagination ---------- @@ -489,3 +517,6 @@ for details. .. _SQLAlchemy: https://www.sqlalchemy.org/ +.. _hybrid attribute: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html +.. _hybrid property: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html#sqlalchemy.ext.hybrid.hybrid_property +.. _hybrid method: https://docs.sqlalchemy.org/en/13/orm/extensions/hybrid.html#sqlalchemy.ext.hybrid.hybrid_method \ No newline at end of file diff --git a/sqlalchemy_filters/models.py b/sqlalchemy_filters/models.py index 4150f42..1c79516 100644 --- a/sqlalchemy_filters/models.py +++ b/sqlalchemy_filters/models.py @@ -1,6 +1,8 @@ from sqlalchemy.exc import InvalidRequestError from sqlalchemy.inspection import inspect from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.util import symbol +import types from .exceptions import BadQuery, FieldNotFound, BadSpec @@ -12,13 +14,41 @@ def __init__(self, model, field_name): self.field_name = field_name def get_sqlalchemy_field(self): - if self.field_name not in inspect(self.model).columns.keys(): + if self.field_name not in self._get_valid_field_names(): raise FieldNotFound( 'Model {} has no column `{}`.'.format( self.model, self.field_name ) ) - return getattr(self.model, self.field_name) + sqlalchemy_field = getattr(self.model, self.field_name) + + # If it's a hybrid method, then we call it so that we can work with + # the result of the execution and not with the method object itself + if isinstance(sqlalchemy_field, types.MethodType): + sqlalchemy_field = sqlalchemy_field() + + return sqlalchemy_field + + def _get_valid_field_names(self): + inspect_mapper = inspect(self.model) + columns = inspect_mapper.columns + orm_descriptors = inspect_mapper.all_orm_descriptors + + column_names = columns.keys() + hybrid_names = [ + key for key, item in orm_descriptors.items() + if _is_hybrid_property(item) or _is_hybrid_method(item) + ] + + return set(column_names) | set(hybrid_names) + + +def _is_hybrid_property(orm_descriptor): + return orm_descriptor.extension_type == symbol('HYBRID_PROPERTY') + + +def _is_hybrid_method(orm_descriptor): + return orm_descriptor.extension_type == symbol('HYBRID_METHOD') def get_query_models(query): diff --git a/test/interface/test_filters.py b/test/interface/test_filters.py index ad0efd2..d904714 100644 --- a/test/interface/test_filters.py +++ b/test/interface/test_filters.py @@ -1215,3 +1215,104 @@ def test_not_any_values_in_array(self, session, is_postgresql): assert len(result) == 2 assert result[0].id == 1 assert result[1].id == 4 + + +class TestHybridAttributes: + + @pytest.mark.usefixtures('multiple_bars_inserted') + @pytest.mark.parametrize( + ('field, expected_error'), + [ + ('foos', "Model has no column `foos`."), + ( + '__mapper__', + "Model has no column `__mapper__`.", + ), + ( + 'not_valid', + "Model has no column `not_valid`.", + ), + ] + ) + def test_orm_descriptors_not_valid_hybrid_attributes( + self, session, field, expected_error + ): + query = session.query(Bar) + filters = [ + { + 'model': 'Bar', + 'field': field, + 'op': '==', + 'value': 100 + } + ] + with pytest.raises(FieldNotFound) as exc: + apply_filters(query, filters) + + assert expected_error in str(exc) + + @pytest.mark.usefixtures('multiple_bars_inserted') + @pytest.mark.usefixtures('multiple_quxs_inserted') + def test_filter_by_hybrid_properties(self, session): + query = session.query(Bar, Qux) + filters = [ + { + 'model': 'Bar', + 'field': 'count_square', + 'op': '==', + 'value': 100 + }, + { + 'model': 'Qux', + 'field': 'count_square', + 'op': '>=', + 'value': 26 + }, + ] + + filtered_query = apply_filters(query, filters) + result = filtered_query.all() + + assert len(result) == 2 + bars, quxs = zip(*result) + + assert set(map(type, bars)) == {Bar} + assert {bar.id for bar in bars} == {2} + assert {bar.count_square for bar in bars} == {100} + + assert set(map(type, quxs)) == {Qux} + assert {qux.id for qux in quxs} == {2, 4} + assert {qux.count_square for qux in quxs} == {100, 225} + + @pytest.mark.usefixtures('multiple_bars_inserted') + @pytest.mark.usefixtures('multiple_quxs_inserted') + def test_filter_by_hybrid_methods(self, session): + query = session.query(Bar, Qux) + filters = [ + { + 'model': 'Bar', + 'field': 'three_times_count', + 'op': '==', + 'value': 30 + }, + { + 'model': 'Qux', + 'field': 'three_times_count', + 'op': '>=', + 'value': 31 + }, + ] + + filtered_query = apply_filters(query, filters) + result = filtered_query.all() + + assert len(result) == 1 + bars, quxs = zip(*result) + + assert set(map(type, bars)) == {Bar} + assert {bar.id for bar in bars} == {2} + assert {bar.three_times_count() for bar in bars} == {30} + + assert set(map(type, quxs)) == {Qux} + assert {qux.id for qux in quxs} == {4} + assert {qux.three_times_count() for qux in quxs} == {45} diff --git a/test/interface/test_sorting.py b/test/interface/test_sorting.py index 3111fd0..a2fd821 100644 --- a/test/interface/test_sorting.py +++ b/test/interface/test_sorting.py @@ -571,3 +571,68 @@ def test_multiple_sort_fields_desc_nulls_last( ('name_4', None), ('name_5', 50), ] + + +class TestSortHybridAttributes(object): + + """Tests that results are sorted only according to the provided + filters. + + Does NOT test how rows with the same values are sorted since this is + not consistent across RDBMS. + + Does NOT test whether `NULL` field values are placed first or last + when sorting since this may differ across RDBMSs. + + SQL defines that `NULL` values should be placed together when + sorting, but it does not specify whether they should be placed first + or last. + """ + + @pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted') + def test_single_sort_hybrid_property_asc(self, session): + query = session.query(Bar) + order_by = [{'field': 'count_square', 'direction': 'asc'}] + + sorted_query = apply_sort(query, order_by) + results = sorted_query.all() + + assert [result.count_square for result in results] == [ + 1, 4, 4, 9, 25, 100, 144, 225 + ] + + @pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted') + def test_single_sort_hybrid_property_desc(self, session): + query = session.query(Bar) + order_by = [{'field': 'count_square', 'direction': 'desc'}] + + sorted_query = apply_sort(query, order_by) + results = sorted_query.all() + + assert [result.count_square for result in results] == [ + 225, 144, 100, 25, 9, 4, 4, 1 + ] + + @pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted') + def test_single_sort_hybrid_method_asc(self, session): + query = session.query(Bar) + order_by = [{'field': 'three_times_count', 'direction': 'asc'}] + + sorted_query = apply_sort(query, order_by) + results = sorted_query.all() + + assert [result.three_times_count() for result in results] == [ + 3, 6, 6, 9, 15, 30, 36, 45 + ] + + @pytest.mark.usefixtures('multiple_bars_with_no_nulls_inserted') + def test_single_sort_hybrid_method_desc(self, session): + query = session.query(Bar) + order_by = [{'field': 'three_times_count', 'direction': 'desc'}] + + sorted_query = apply_sort(query, order_by) + results = sorted_query.all() + + assert [result.three_times_count() for result in results] == [ + 45, 36, 30, 15, 9, 6, 6, 3 + ] diff --git a/test/models.py b/test/models.py index 10ecd62..6484c3f 100644 --- a/test/models.py +++ b/test/models.py @@ -5,6 +5,7 @@ ) from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.ext.hybrid import hybrid_property, hybrid_method from sqlalchemy.orm import relationship @@ -13,6 +14,14 @@ class Base(object): name = Column(String(50), nullable=False) count = Column(Integer, nullable=True) + @hybrid_property + def count_square(self): + return self.count * self.count + + @hybrid_method + def three_times_count(self): + return self.count * 3 + Base = declarative_base(cls=Base) BasePostgresqlSpecific = declarative_base(cls=Base)