Skip to content

Commit

Permalink
Add mor and __gor__ filters (#21)
Browse files Browse the repository at this point in the history
* Adding mor and gor for or clause

* Fix and optimize filters

* Add usage documents
  • Loading branch information
wu-clan authored Aug 25, 2024
1 parent 43800c1 commit eb54da9
Show file tree
Hide file tree
Showing 3 changed files with 139 additions and 36 deletions.
47 changes: 41 additions & 6 deletions docs/advanced/filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ items = await item_crud.select_models(

运算符需要多个值,且仅允许元组,列表,集合

```python
# 获取年龄在 30 - 40 岁之间的员工
```python title="__between"
# 获取年龄在 30 - 40 岁之间且名字在目标列表的员工
items = await item_crud.select_models(
session=db,
age__between=[30, 40],
name__in=['bob', 'lucy'],
)
```

Expand All @@ -86,7 +87,7 @@ items = await item_crud.select_models(
可以通过将多个过滤器链接在一起来实现 AND 子句

```python
# 获取年龄在 30 以上,薪资大于 2w 的员工
# 获取年龄在 30 以上,薪资大于 20k 的员工
items = await item_crud.select_models(
session=db,
age__gt=30,
Expand All @@ -100,14 +101,48 @@ items = await item_crud.select_models(

每个键都应是库已支持的过滤器,仅允许字典

```python
```python title="__or"
# 获取年龄在 40 岁以上或 30 岁以下的员工
items = await item_crud.select_models(
session=db,
age__or={'gt': 40, 'lt': 30},
)
```

## MOR

!!! note

`or` 过滤器的高级用法,每个键都应是库已支持的过滤器,仅允许字典

```python title="__mor"
# 获取年龄等于 30 岁和 40 岁的员工
items = await item_crud.select_models(
session=db,
age__mor={'eq': [30, 40]}, # (1)
)
```

1. 原因:在 python 字典中,不允许存在相同的键值;<br/>
场景:我有一个列,需要多个相同条件但不同条件值的查询,此时,你应该使用 `mor` 过滤器,正如此示例一样使用它

## GOR

!!! note

`or` 过滤器的更高级用法,每个值都应是一个已受支持的条件过滤器,它应该是一个数组

```python title="__gor__"
# 获取年龄在 30 - 40 岁之间且薪资大于 20k 的员工
items = await item_crud.select_models(
session=db,
__gor__=[
{'age__between': [30, 40]},
{'payroll__gt': 20000}
]
)
```

## 算数

!!! note
Expand All @@ -119,9 +154,9 @@ items = await item_crud.select_models(
`condition`:此值将作为运算后的比较值,比较条件取决于使用的过滤器

```python
# 获取薪资打八折以后仍高于 15000 的员工
# 获取薪资打八折以后仍高于 20k 的员工
items = await item_crud.select_models(
session=db,
payroll__mul={'value': 0.8, 'condition': {'gt': 15000}},
payroll__mul={'value': 0.8, 'condition': {'gt': 20000}},
)
```
103 changes: 74 additions & 29 deletions sqlalchemy_crud_plus/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from sqlalchemy import ColumnElement, Select, and_, asc, desc, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import InstrumentedAttribute
from sqlalchemy.orm.util import AliasedClass

from sqlalchemy_crud_plus.errors import ColumnSortError, ModelColumnError, SelectOperatorError
Expand Down Expand Up @@ -70,7 +71,7 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
raise SelectOperatorError(f'Nested arithmetic operations are not allowed: {operator}')

sqlalchemy_filter = _SUPPORTED_FILTERS.get(operator)
if sqlalchemy_filter is None:
if sqlalchemy_filter is None and operator not in ['or', 'mor', '__gor']:
warnings.warn(
f'The operator <{operator}> is not yet supported, only {", ".join(_SUPPORTED_FILTERS.keys())}.',
SyntaxWarning,
Expand All @@ -80,48 +81,92 @@ def get_sqlalchemy_filter(operator: str, value: Any, allow_arithmetic: bool = Tr
return sqlalchemy_filter


def get_column(model: Type[Model] | AliasedClass, field_name: str):
def get_column(model: Type[Model] | AliasedClass, field_name: str) -> InstrumentedAttribute | None:
column = getattr(model, field_name, None)
if column is None:
raise ModelColumnError(f'Column {field_name} is not found in {model}')
return column


def _create_or_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
or_filters = []
if op == 'or':
for or_op, or_value in value.items():
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
if sqlalchemy_filter is not None:
or_filters.append(sqlalchemy_filter(column)(or_value))
elif op == 'mor':
for or_op, or_values in value.items():
for or_value in or_values:
sqlalchemy_filter = get_sqlalchemy_filter(or_op, or_value)
if sqlalchemy_filter is not None:
or_filters.append(sqlalchemy_filter(column)(or_value))
return or_filters


def _create_arithmetic_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
arithmetic_filters = []
if isinstance(value, dict) and {'value', 'condition'}.issubset(value):
arithmetic_value = value['value']
condition = value['condition']
sqlalchemy_filter = get_sqlalchemy_filter(op, arithmetic_value)
if sqlalchemy_filter is not None:
for cond_op, cond_value in condition.items():
arithmetic_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
arithmetic_filters.append(
arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(cond_value)
if cond_op != 'between'
else arithmetic_filter(sqlalchemy_filter(column)(arithmetic_value))(*cond_value)
)
return arithmetic_filters


def _create_and_filters(column: str, op: str, value: Any) -> list[ColumnElement | None]:
and_filters = []
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
if sqlalchemy_filter is not None:
and_filters.append(sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value))
return and_filters


def parse_filters(model: Type[Model] | AliasedClass, **kwargs) -> list[ColumnElement]:
filters = []

def process_filters(target_column: str, target_op: str, target_value: Any):
# OR / MOR
or_filters = _create_or_filters(target_column, target_op, target_value)
if or_filters:
filters.append(or_(*or_filters))

# ARITHMETIC
arithmetic_filters = _create_arithmetic_filters(target_column, target_op, target_value)
if arithmetic_filters:
filters.append(and_(*arithmetic_filters))
else:
# AND
and_filters = _create_and_filters(target_column, target_op, target_value)
if and_filters:
filters.append(*and_filters)

for key, value in kwargs.items():
if '__' in key:
field_name, op = key.rsplit('__', 1)
column = get_column(model, field_name)
if op == 'or':
or_filters = [
sqlalchemy_filter(column)(or_value)
for or_op, or_value in value.items()
if (sqlalchemy_filter := get_sqlalchemy_filter(or_op, or_value)) is not None
]
filters.append(or_(*or_filters))
elif isinstance(value, dict) and {'value', 'condition'}.issubset(value):
advanced_value = value['value']
condition = value['condition']
sqlalchemy_filter = get_sqlalchemy_filter(op, advanced_value)
if sqlalchemy_filter is not None:
condition_filters = []
for cond_op, cond_value in condition.items():
condition_filter = get_sqlalchemy_filter(cond_op, cond_value, allow_arithmetic=False)
condition_filters.append(
condition_filter(sqlalchemy_filter(column)(advanced_value))(cond_value)
if cond_op != 'between'
else condition_filter(sqlalchemy_filter(column)(advanced_value))(*cond_value)
)
filters.append(and_(*condition_filters))

# OR GROUP
if field_name == '__gor' and op == '':
_or_filters = []
for field_or in value:
for _key, _value in field_or.items():
_field_name, _op = _key.rsplit('__', 1)
_column = get_column(model, _field_name)
process_filters(_column, _op, _value)
if _or_filters:
filters.append(or_(*_or_filters))
else:
sqlalchemy_filter = get_sqlalchemy_filter(op, value)
if sqlalchemy_filter is not None:
filters.append(
sqlalchemy_filter(column)(value) if op != 'between' else sqlalchemy_filter(column)(*value)
)
column = get_column(model, field_name)
process_filters(column, op, value)
else:
# NON FILTER
column = get_column(model, key)
filters.append(column == value)

Expand Down
25 changes: 24 additions & 1 deletion tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def test_select_model_by_column_with_ne(create_test_model, async_db_sessio
async def test_select_model_by_column_with_between(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(session, id__between=(0, 11))
result = await crud.select_model_by_column(session, id__between=(0, 10))
assert result.id == 1


Expand Down Expand Up @@ -338,6 +338,29 @@ async def test_select_model_by_column_with_or(create_test_model, async_db_sessio
assert result.id == 1


@pytest.mark.asyncio
async def test_select_model_by_column_with_mor(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(session, id__mor={'eq': [1, 2, 3, 4, 5, 6, 7, 8, 9]})
assert result.id == 1


@pytest.mark.asyncio
async def test_select_model_by_column_with___gor__(create_test_model, async_db_session):
async with async_db_session() as session:
crud = CRUDPlus(Ins)
result = await crud.select_model_by_column(
session,
__gor__=[
{'id__eq': 1},
{'name__mor': {'endswith': ['1', '2']}},
{'id__mul': {'value': 1, 'condition': {'eq': 1}}},
],
)
assert result.id == 1


@pytest.mark.asyncio
async def test_select(create_test_model):
crud = CRUDPlus(Ins)
Expand Down

0 comments on commit eb54da9

Please sign in to comment.