Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- fix typings
  • Loading branch information
devkral committed Oct 24, 2024
1 parent 86a0c57 commit ec48eff
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 107 deletions.
7 changes: 5 additions & 2 deletions docs/queries/queries.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ users = await User.query.filter(and_.from_kwargs(User, name="foo", email="foo@ex
users = await User.query.filter(and_.from_kwargs(User, **my_dict))
```

#### Global OR
#### OR

Edgy QuerySet can do global ORs. This means you can attach new OR clauses also later.

Expand Down Expand Up @@ -317,7 +317,7 @@ user_query = user_query.or_(user_query, {"email": "outlook"}, {"email": "gmail"}
users = await user_query
```

#### Passing multiple keyword based filters
##### Passing multiple keyword based filters

You can also passing multiple keyword based filters by providing them as a dictionary

Expand All @@ -326,7 +326,10 @@ user_query = User.query.or_({"active": True}, {"email": "outlook"}, {"email": "g
# active users or users with email gmail or outlook are retrieved
users = await user_query
```
##### Local only OR

If the special mode of or_ is not wanted there is a function named `local_or`. It is similar
to the or_ function except it doesn't have the global OR mode.

### Limit

Expand Down
10 changes: 6 additions & 4 deletions docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,22 @@ hide:

# Release Notes

## Unreleased
## 0.20.0

### Added

- Add DurationField.
- Allow passing `max_digits` to FloatField.
- Add `local_or` function to QuerySets.

### Changed

- Only the main table of a queryset is queryable via `model_class.columns.foo == foo`. Joins are now unique.
- Only the main table of a queryset is queryable via `model_class.columns.foo == foo`. Select related models have now an unique name for their path.
The name can be retrieved via `tables_and_models` or using `f"{hash_tablekey(...)}_{column}"`.
- Alter tables_and_models to use the prefix as key.
- Functions passed to filter functions reveive now the second positional parameter `tables_and_models`.
- Breaking: Alter tables_and_models to use the prefix as key with '' for the maintable and model.
- Breaking: Functions passed to filter functions reveive now a second positional parameter `tables_and_models`.
- `build_where_clause` conditionally returns exist subquery.
- Rename QueryType to QuerySetType. The old name stays as an alias.

### Fixed

Expand Down
4 changes: 2 additions & 2 deletions edgy/core/db/models/mixins/row.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, cast

from edgy.core.db.fields.base import RelationshipField
from edgy.core.db.models.utils import apply_instance_extras
Expand Down Expand Up @@ -275,7 +275,7 @@ async def __set_prefetch(
if queryset is None:
queryset = crawl_result.model_class.query.all()

queryset = queryset.select_related(crawl_result.reverse_path)
queryset = queryset.select_related(cast(str, crawl_result.reverse_path))
clause = {
f"{crawl_result.reverse_path}__{pkcol}": row._mapping[f"{row_prefix}{pkcol}"]
for pkcol in cls.pkcolumns
Expand Down
97 changes: 32 additions & 65 deletions edgy/core/db/querysets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
from edgy.core.db.models.model_reference import ModelRef
from edgy.core.db.models.types import BaseModelType
from edgy.core.db.models.utils import apply_instance_extras
from edgy.core.db.querysets.mixins import QuerySetPropsMixin, TenancyMixin
from edgy.core.db.querysets.prefetch import Prefetch, PrefetchMixin, check_prefetch_collision
from edgy.core.db.querysets.types import EdgyEmbedTarget, EdgyModel, QueryType
from edgy.core.db.relationships.utils import crawl_relationship
from edgy.core.utils.db import check_db_connection, hash_tablekey
from edgy.core.utils.sync import run_sync
from edgy.exceptions import MultipleObjectsReturned, ObjectNotFound, QuerySetError
from edgy.types import Undefined

from . import clauses as clauses_mod
from .mixins import QuerySetPropsMixin, TenancyMixin
from .prefetch import Prefetch, PrefetchMixin, check_prefetch_collision
from .types import EdgyEmbedTarget, EdgyModel, QuerySetType, tables_and_models_type

if TYPE_CHECKING: # pragma: no cover
from databasez.core.transaction import Transaction
Expand All @@ -44,12 +44,10 @@
generic_field = BaseField()
_empty_set = cast(Sequence[Any], frozenset())

tables_and_models_type = dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]


def get_table_key_or_name(table: Any) -> str:
def get_table_key_or_name(table: Union[sqlalchemy.Table, sqlalchemy.Alias]) -> str:
try:
return table.key
return table.key # type: ignore
except AttributeError:
# alias
return table.name
Expand Down Expand Up @@ -81,21 +79,11 @@ def clean_query_kwargs(
return new_kwargs


async def _parse_clause_arg(
arg: Any, instance: "BaseQuerySet", tables_and_models: tables_and_models_type
) -> Any:
if callable(arg):
arg = arg(instance, tables_and_models)
if isawaitable(arg):
arg = await arg
return arg


class BaseQuerySet(
TenancyMixin,
QuerySetPropsMixin,
PrefetchMixin,
QueryType,
QuerySetType,
):
"""Internal definitions for queryset."""

Expand Down Expand Up @@ -253,7 +241,7 @@ async def _resolve_clause_args(
) -> Any:
result: list[Any] = []
for arg in args:
result.append(_parse_clause_arg(arg, self, tables_and_models))
result.append(clauses_mod.parse_clause_arg(arg, self, tables_and_models))
if self.database.force_rollback:
return [await el for el in result]
else:
Expand All @@ -263,7 +251,7 @@ async def build_where_clause(
self, _: Any = None, tables_and_models: Optional[tables_and_models_type] = None
) -> Any:
"""Build a where clause from the filters which can be passed in a where function."""
joins = None
joins: Optional[Any] = None
if tables_and_models is None:
joins, tables_and_models = self._build_tables_join_from_relationship()
# ignored args for passing build_where_clause in filter_clauses
Expand Down Expand Up @@ -309,7 +297,7 @@ def _join_table_helper(
join_clause: Any,
current_transition: tuple[str, str, str],
*,
transitions: dict[tuple[str, str, str], tuple[Any, Optional[tuple[str, str]], str]],
transitions: dict[tuple[str, str, str], tuple[Any, Optional[tuple[str, str, str]], str]],
tables_and_models: dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]],
) -> Any:
if current_transition not in transitions:
Expand All @@ -333,9 +321,7 @@ def _join_table_helper(

def _build_tables_join_from_relationship(
self,
) -> tuple[
str, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]], dict[str, set[str]], Any
]:
) -> tuple[Any, tables_and_models_type]:
"""
Builds the tables relationships and joins.
When a table contains more than one foreign key pointing to the same
Expand Down Expand Up @@ -427,7 +413,7 @@ def _build_tables_join_from_relationship(
reverse = True
if _select_prefix in _select_tables_and_models:
# use prexisting prefix
table = _select_tables_and_models[_select_prefix][0]
table: Any = _select_tables_and_models[_select_prefix][0]
else:
table = model_class.table_schema(self.active_schema)
table = table.alias(hash_tablekey(tablekey=table.key, prefix=prefix))
Expand Down Expand Up @@ -502,7 +488,7 @@ async def _as_select_with_tables(
"""
self._validate_only_and_defer()
joins, tables_and_models = self._build_tables_join_from_relationship()
columns = []
columns: list[Any] = []
for prefix, (table, model_class) in tables_and_models.items():
if not prefix:
for column_key, column in table.columns.items():
Expand Down Expand Up @@ -628,36 +614,12 @@ async def wrapper(
return fk_tuple.in_(await _sub_query)

clauses.append(wrapper)
elif callable(value):
# bind local vars
async def wrapper(
queryset: "QuerySet",
tables_and_models: tables_and_models_type,
*,
_field: "BaseFieldType" = field,
_value: Any = value,
_op: Optional[str] = op,
_prefix: str = related_str,
) -> Any:
_value = _value(queryset, tables_and_models)
if isawaitable(_value):
_value = await _value
table = tables_and_models[_prefix][0]
return _field.operator_to_clause(
_field.name,
_op,
table,
_value,
)

clauses.append(wrapper)

else:
assert not isinstance(
value, BaseModelType
), f"should be parsed in clean: {key}: {value}"

def wrapper(
async def wrapper(
queryset: "QuerySet",
tables_and_models: tables_and_models_type,
*,
Expand All @@ -666,6 +628,9 @@ def wrapper(
_op: Optional[str] = op,
_prefix: str = related_str,
) -> Any:
_value = await clauses_mod.parse_clause_arg(
_value, queryset, tables_and_models
)
table = tables_and_models[_prefix][0]
return _field.operator_to_clause(_field.name, _op, table, _value)

Expand Down Expand Up @@ -776,24 +741,26 @@ async def _handle_batch(
for row in batch
]
if prefetch_queryset is None:
prefetch_queryset = crawl_result.model_class.query.or_local(*clauses)
prefetch_queryset = crawl_result.model_class.query.local_or(*clauses)
else:
# ensure local or
prefetch_queryset = prefetch_queryset.or_local(*clauses)
prefetch_queryset = prefetch_queryset.local_or(*clauses)

if prefetch_queryset.model_class is self.model_class:
# queryset is of this model
prefetch_queryset = prefetch_queryset.select_related(prefetch.related_name)
prefetch_queryset.embed_parent = (prefetch.related_name, "")
else:
# queryset is of the target model
prefetch_queryset = prefetch_queryset.select_related(crawl_result.reverse_path)
prefetch_queryset = prefetch_queryset.select_related(
cast(str, crawl_result.reverse_path)
)
new_prefetch = Prefetch(
related_name=prefetch.related_name,
to_attr=prefetch.to_attr,
queryset=prefetch_queryset,
)
new_prefetch._bake_prefix = f"{hash_tablekey(tablekey=tables_and_models[''][0], prefix=crawl_result.reverse_path)}_"
new_prefetch._bake_prefix = f"{hash_tablekey(tablekey=tables_and_models[''][0].key, prefix=cast(str, crawl_result.reverse_path))}_"
new_prefetch._is_finished = True
_prefetch_related.append(new_prefetch)

Expand Down Expand Up @@ -899,7 +866,7 @@ def _filter_or_exclude(
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand All @@ -923,7 +890,7 @@ def _filter_or_exclude(
Union[
sqlalchemy.sql.expression.BinaryExpression,
Callable[
[QueryType],
[QuerySetType],
Union[
sqlalchemy.sql.expression.BinaryExpression,
Awaitable[sqlalchemy.sql.expression.BinaryExpression],
Expand All @@ -946,7 +913,7 @@ async def wrapper_and(
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand Down Expand Up @@ -1082,7 +1049,7 @@ def filter(
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand Down Expand Up @@ -1112,7 +1079,7 @@ def or_(
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand All @@ -1128,12 +1095,12 @@ def or_(
"""
return self._filter_or_exclude(clauses=clauses, or_=True, kwargs=kwargs)

def or_local(
def local_or(
self,
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand All @@ -1156,7 +1123,7 @@ def and_(
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand All @@ -1176,7 +1143,7 @@ def not_(
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand All @@ -1197,7 +1164,7 @@ def exclude(
*clauses: Union[
"sqlalchemy.sql.expression.BinaryExpression",
Callable[
["QueryType"],
["QuerySetType"],
Union[
"sqlalchemy.sql.expression.BinaryExpression",
Awaitable["sqlalchemy.sql.expression.BinaryExpression"],
Expand Down
18 changes: 16 additions & 2 deletions edgy/core/db/querysets/clauses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Union

import sqlalchemy
Expand All @@ -8,6 +9,18 @@
if TYPE_CHECKING:
from edgy.core.db.models import Model

from .types import QuerySetType, tables_and_models_type


async def parse_clause_arg(
arg: Any, instance: QuerySetType, tables_and_models: tables_and_models_type
) -> Any:
if callable(arg):
arg = arg(instance, tables_and_models)
if isawaitable(arg):
arg = await arg
return arg


class _EnhancedClausesHelper:
def __init__(self, op: Any, default_empty: Any) -> None:
Expand All @@ -20,12 +33,13 @@ def __call__(self, *args: Any) -> Any:
return self.op(*args)

def from_kwargs(
self, columns_or_model: Union[Model, ColumnCollection], /, **kwargs: Any
self, columns_or_model: Union[Model, ColumnCollection, sqlalchemy.Table], /, **kwargs: Any
) -> Any:
# inferior to the kwargs parser of QuerySet
if not isinstance(columns_or_model, ColumnCollection) and hasattr(
columns_or_model, "columns"
):
columns_or_model = columns_or_model.table.columns
columns_or_model = columns_or_model.columns
return self.op(*(getattr(columns_or_model, item[0]) == item[1] for item in kwargs.items()))


Expand Down
Loading

0 comments on commit ec48eff

Please sign in to comment.