diff --git a/sqlalchemy_filters/filters.py b/sqlalchemy_filters/filters.py index 356c4fd..bf9d2b6 100644 --- a/sqlalchemy_filters/filters.py +++ b/sqlalchemy_filters/filters.py @@ -58,6 +58,9 @@ class Operator(object): 'not_in': lambda f, a: ~f.in_(a), 'any': lambda f, a: f.any(a), 'not_any': lambda f, a: func.not_(f.any(a)), + 'json_contains': lambda f, a: func.json_contains( + f, func.json_array(a) + ), } def __init__(self, operator=None): diff --git a/test/conftest.py b/test/conftest.py index ffe3dc7..33fbce2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -108,6 +108,15 @@ def is_sqlite(db_uri): return False +@pytest.fixture(scope='session') +def is_sqlalchemy_1_3_or_higer(): + import sqlalchemy + if sqlalchemy.__version__ >= '1.3.0': + return True + else: + return False + + @pytest.fixture(scope='session') def db_engine_options(db_uri, is_postgresql): if is_postgresql: diff --git a/test/interface/test_filters.py b/test/interface/test_filters.py index d904714..b693f78 100644 --- a/test/interface/test_filters.py +++ b/test/interface/test_filters.py @@ -12,7 +12,7 @@ BadFilterFormat, BadSpec, FieldNotFound ) -from test.models import Foo, Bar, Qux, Corge +from test.models import Foo, Bar, Qux, Corge, Til ARRAY_NOT_SUPPORTED = ( @@ -85,6 +85,18 @@ def multiple_corges_inserted(session, is_postgresql): session.commit() +@pytest.fixture +def multiple_tils_inserted(session, is_sqlalchemy_1_3_or_higer): + if is_sqlalchemy_1_3_or_higer: + + til_1 = Til(id=1, name='name_1', refer_info=[]) + til_2 = Til(id=2, name='name_2', refer_info=[1]) + til_3 = Til(id=3, name='name_3', refer_info=[2, 3]) + til_4 = Til(id=4, name='name_4', refer_info=['foo', 'baz']) + session.add_all([til_1, til_2, til_3, til_4]) + session.commit() + + class TestFiltersNotApplied: def test_no_filters_provided(self, session): @@ -1316,3 +1328,77 @@ def test_filter_by_hybrid_methods(self, session): assert set(map(type, quxs)) == {Qux} assert {qux.id for qux in quxs} == {4} assert {qux.three_times_count() for qux in quxs} == {45} + + +class TestApplyJsonContainsFilter: + + @pytest.mark.usefixtures('multiple_tils_inserted') + def test_til_not_contains_value( + self, session, is_sqlite, is_postgresql, is_sqlalchemy_1_3_or_higer + ): + if is_sqlite: + pytest.skip() + + if is_postgresql: + pytest.skip() + + if not is_sqlalchemy_1_3_or_higer: + pytest.skip() + + query = session.query(Til) + filters = [ + {'field': 'refer_info', 'op': 'json_contains', 'value': 'invalid'} + ] + + filtered_query = apply_filters(query, filters) + result = filtered_query.all() + + assert len(result) == 0 + + @pytest.mark.usefixtures('multiple_tils_inserted') + def test_til_contains_int_value( + self, session, is_sqlite, is_postgresql, is_sqlalchemy_1_3_or_higer + ): + if is_sqlite: + pytest.skip() + + if is_postgresql: + pytest.skip() + + if not is_sqlalchemy_1_3_or_higer: + pytest.skip() + + query = session.query(Til) + filters = [ + {'field': 'refer_info', 'op': 'json_contains', 'value': 1} + ] + + filtered_query = apply_filters(query, filters) + result = filtered_query.all() + + assert len(result) == 1 + assert result[0].id == 2 + + @pytest.mark.usefixtures('multiple_tils_inserted') + def test_til_contains_str_value( + self, session, is_sqlite, is_postgresql, is_sqlalchemy_1_3_or_higer + ): + if is_sqlite: + pytest.skip() + + if is_postgresql: + pytest.skip() + + if not is_sqlalchemy_1_3_or_higer: + pytest.skip() + + query = session.query(Til) + filters = [ + {'field': 'refer_info', 'op': 'json_contains', 'value': "foo"} + ] + + filtered_query = apply_filters(query, filters) + result = filtered_query.all() + + assert len(result) == 1 + assert result[0].id == 4 diff --git a/test/models.py b/test/models.py index 6484c3f..8038194 100644 --- a/test/models.py +++ b/test/models.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- +import sqlalchemy from sqlalchemy import ( Column, Date, DateTime, ForeignKey, Integer, String, Time ) @@ -62,3 +63,16 @@ class Corge(BasePostgresqlSpecific): __tablename__ = 'corge' tags = Column(ARRAY(String, dimensions=1)) + + +if sqlalchemy.__version__ >= '1.3.0': + + from sqlalchemy import JSON + + class Til(Base): + + __tablename__ = 'til' + + refer_info = Column(JSON) +else: + Til = None