diff --git a/docs/queries/queries.md b/docs/queries/queries.md index decb654f..ba790294 100644 --- a/docs/queries/queries.md +++ b/docs/queries/queries.md @@ -262,6 +262,10 @@ users = await User.query.filter(User.columns.id.in_([1, 2, 3])) !!! Warning The `columns` refers to the columns of the underlying SQLAlchemy table. +!!! Warning + This works only for the main model of a QuerySet. Related models are handled via `f"{hash_tablekey(tablekey=model.table.key, prefix=...)}_{columnkey}"`. + You can pass the column via `sqlalchemy.column` (lowercase column). + All the operations you would normally do in SQLAlchemy syntax, are allowed here. ##### Using `and_` and `or_` with kwargs @@ -277,7 +281,7 @@ users = await User.query.filter(and_.from_kwargs(User, name="foo", email="foo@ex users = await User.query.filter(and_.from_kwargs(User, **my_dict)) ``` -#### Global OR +#### OR Edgy QuerySet can do global ORs. This means you can attach new OR clauses also later. @@ -313,7 +317,7 @@ user_query = user_query.or_(user_query, {"email": "outlook"}, {"email": "gmail"} users = await user_query ``` -#### Passing multiple keyword based filters +##### Passing multiple keyword based filters You can also passing multiple keyword based filters by providing them as a dictionary @@ -322,7 +326,10 @@ user_query = User.query.or_({"active": True}, {"email": "outlook"}, {"email": "g # active users or users with email gmail or outlook are retrieved users = await user_query ``` +##### Local only OR +If the global mode of `or_` is not wanted, there is a function named `local_or`. It is similar +to `or_` except that it doesn't have the global OR mode. ### Limit @@ -1090,6 +1097,7 @@ The pendant in a model are `identifying_clauses`. query = Model.query.filter(id=1) # ensures that the db connection doesn't drop during operation async with query.database as database: + # when joins are involved, a subquery is generated expression = query.table.select().where(await query.build_where_clause()) # as generic sql print(str(expression)) diff --git a/docs/release-notes.md b/docs/release-notes.md index 78153c2b..5370f097 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -6,17 +6,29 @@ hide: # Release Notes -## Unreleased +## 0.20.0 ### Added - Add DurationField. - Allow passing `max_digits` to FloatField. +- Add `local_or` function to QuerySets. + +### Changed + +- Only the main table of a queryset is queryable via `model_class.columns.foo == foo`. Select related models now have a unique name for their path. + The name can be retrieved via `tables_and_models` or using `f"{hash_tablekey(...)}_{column}"`. +- Breaking: Alter tables_and_models to use the prefix as key, with '' for the main table and model. +- Breaking: Functions passed to filter functions now receive a second positional parameter `tables_and_models`. +- `build_where_clause` conditionally uses a subquery. +- Rename QueryType to QuerySetType. The old name stays as an alias. ### Fixed - Triggering load on non-existent field when reflecting. - InspectDB mapping was incorrect. +- Fix query edge cases. +- Fix using related queries with update/delete.
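To make the related-column change above concrete, here is a minimal sketch, assuming a hypothetical `User`/`Profile` pair on a SQLite registry; the relation prefix `"user"` simply mirrors the `select_related` path, and the labelled-column variant follows the warning added to the docs above.

```python
import sqlalchemy

import edgy
from edgy.core.utils.db import hash_tablekey

# Hypothetical models and registry, used only for illustration.
models = edgy.Registry(database=edgy.Database("sqlite:///example.sqlite"))


class User(edgy.Model):
    name = edgy.fields.CharField(max_length=100)

    class Meta:
        registry = models


class Profile(edgy.Model):
    user = edgy.fields.ForeignKey(User, related_name="profiles")
    bio = edgy.fields.CharField(max_length=255)

    class Meta:
        registry = models


# Columns of the main model can still be compared directly.
query = Profile.query.filter(Profile.columns.id.in_([1, 2, 3]))

# Related tables are aliased per select_related path. A filter callable now
# receives (queryset, tables_and_models) and can look the joined table up by
# its relation prefix ("" is the main table).
query = Profile.query.select_related("user").filter(
    lambda queryset, tables_and_models: tables_and_models["user"][0].columns.name == "edgy"
)

# Alternatively, a loose sqlalchemy.column matching the hashed label can be
# passed, as the warning above describes.
label = f"{hash_tablekey(tablekey=User.table.key, prefix='user')}_name"
query = Profile.query.select_related("user").filter(sqlalchemy.column(label) == "edgy")
```

Aliasing every joined table by a hash of its relation path is what gives each select-related model the unique name mentioned in the release notes, so the same table can appear more than once in a join without its columns colliding.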
## 0.19.1 diff --git a/edgy/core/db/models/mixins/db.py b/edgy/core/db/models/mixins/db.py index ae61e935..18b3256f 100644 --- a/edgy/core/db/models/mixins/db.py +++ b/edgy/core/db/models/mixins/db.py @@ -235,13 +235,19 @@ def get_columns_for_name(self: "Model", name: str) -> Sequence["sqlalchemy.Colum else: return cast(Sequence["sqlalchemy.Column"], _empty) - def identifying_clauses(self) -> list[Any]: + def identifying_clauses(self, prefix: str = "") -> list[Any]: + # works only if the class of the model is the main class of the queryset + # TODO: implement prefix handling and return generic column without table attached + if prefix: + raise NotImplementedError() clauses: list[Any] = [] for field_name in self.identifying_db_fields: field = self.meta.fields.get(field_name) if field is not None: - for column, value in field.clean(field_name, self.__dict__[field_name]).items(): - clauses.append(getattr(self.table.columns, column) == value) + for column_name, value in field.clean( + field_name, self.__dict__[field_name] + ).items(): + clauses.append(getattr(self.table.columns, column_name) == value) else: clauses.append( getattr(self.table.columns, field_name) == self.__dict__[field_name] diff --git a/edgy/core/db/models/mixins/row.py b/edgy/core/db/models/mixins/row.py index 4918fc20..ecaee888 100644 --- a/edgy/core/db/models/mixins/row.py +++ b/edgy/core/db/models/mixins/row.py @@ -1,6 +1,6 @@ import asyncio from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, cast from edgy.core.db.fields.base import RelationshipField from edgy.core.db.models.utils import apply_instance_extras @@ -28,10 +28,7 @@ def can_load_from_row(cls: type["Model"], row: "Row", table: "Table") -> bool: return bool( cls.meta.registry and not cls.meta.abstract - and all( - row._mapping.get(f"{table.key.replace('.', '_')}_{col}") is not None - for col in cls.pkcolumns - ) + and all(row._mapping.get(f"{table.name}_{col}") is not None for col in cls.pkcolumns) ) @classmethod @@ -87,17 +84,13 @@ async def from_sqla_row( detail=f'Selected field "{field_name}" is not a RelationshipField on {cls}.' ) from None + _prefix = field_name if not prefix else f"{prefix}__{field_name}" # stop selecting when None. Related models are not available. 
if not model_class.can_load_from_row( row, - tables_and_models[ - model_class.meta.tablename - if using_schema is None - else f"{using_schema}.{model_class.meta.tablename}" - ][0], + tables_and_models[_prefix][0], ): continue - _prefix = field_name if not prefix else f"{prefix}__{field_name}" if remainder: # don't pass table, it is only for the main model_class @@ -130,9 +123,7 @@ async def from_sqla_row( for k, v in item.items(): setattr(old_select_related_value, k, v) return old_select_related_value - table_columns = tables_and_models[ - cls.meta.tablename if using_schema is None else f"{using_schema}.{cls.meta.tablename}" - ][0].columns + table_columns = tables_and_models[prefix][0].columns # Populate the related names # Making sure if the model being queried is not inside a select related # This way it is not overritten by any value @@ -153,12 +144,15 @@ async def from_sqla_row( child_item = {} for column_name in columns_to_check: column = getattr(table_columns, column_name, None) - if ( - column is not None - and f"{column.table.key.replace('.', '_')}_{column.key}" in row._mapping - ): + if column_name is None: + continue + columnkeyhash = column_name + if prefix: + columnkeyhash = f"{tables_and_models[prefix][0].name}_{column.key}" + + if columnkeyhash in row._mapping: child_item[foreign_key.from_fk_field_name(related, column_name)] = ( - row._mapping[f"{column.table.key.replace('.', '_')}_{column.key}"] + row._mapping[columnkeyhash] ) # Make sure we generate a temporary reduced model # For the related fields. We simply chnage the structure of the model @@ -190,13 +184,14 @@ async def from_sqla_row( if column.key not in cls.meta.columns_to_field: continue # set if not of an foreign key with one column - elif ( - column.key not in item - and f"{column.table.key.replace('.', '_')}_{column.key}" in row._mapping - ): - item[column.key] = row._mapping[ - f"{column.table.key.replace('.', '_')}_{column.key}" - ] + if column.key in item: + continue + columnkeyhash = column.key + if prefix: + columnkeyhash = f"{tables_and_models[prefix][0].name}_{columnkeyhash}" + + if columnkeyhash in row._mapping: + item[column.key] = row._mapping[columnkeyhash] model: Model = ( cls.proxy_model(**item, __phase__="init_db") # type: ignore if exclude_secrets or is_defer_fields or only_fields @@ -208,24 +203,18 @@ async def from_sqla_row( cls, using_schema, database=database, - table=tables_and_models[ - cls.meta.tablename - if using_schema is None - else f"{using_schema}.{cls.meta.tablename}" - ][0], + table=tables_and_models[prefix][0], ) - # Handle prefetch related fields. - await cls.__handle_prefetch_related( - row=row, - table=tables_and_models[ - cls.meta.tablename - if using_schema is None - else f"{using_schema}.{cls.meta.tablename}" - ][0], - model=model, - prefetch_related=prefetch_related, - ) + if prefetch_related: + # Handle prefetch related fields. + await cls.__handle_prefetch_related( + row=row, + prefix=prefix, + model=model, + tables_and_models=tables_and_models, + prefetch_related=prefetch_related, + ) assert model.pk is not None, model return model @@ -243,27 +232,21 @@ def __should_ignore_related_name( return False @classmethod - def create_model_key_from_sqla_row( - cls, - row: "Row", - ) -> tuple: + def create_model_key_from_sqla_row(cls, row: "Row", row_prefix: str = "") -> tuple: """ Build a cache key for the model. 
""" pk_key_list: list[Any] = [cls.__name__] for attr in cls.pkcolumns: - try: - pk_key_list.append(str(row._mapping[getattr(cls.table.columns, attr)])) - except KeyError: - pk_key_list.append(str(row._mapping[attr])) + pk_key_list.append(str(row._mapping[f"{row_prefix}{attr}"])) return tuple(pk_key_list) @classmethod async def __set_prefetch( cls, row: "Row", - table: "Table", model: "Model", + row_prefix: str, related: "Prefetch", ) -> None: model_key = () @@ -273,48 +256,39 @@ async def __set_prefetch( if model_key in related._baked_results: setattr(model, related.to_attr, related._baked_results[model_key]) else: - clauses = [] - for pkcol in cls.pkcolumns: - clauses.append( - getattr(table.columns, pkcol) - == row._mapping[f"{table.key.replace('.', '_')}_{pkcol}"] + crawl_result = crawl_relationship( + model.__class__, related.related_name, traverse_last=True + ) + if crawl_result.reverse_path is False: + QuerySetError( + detail=("Creating a reverse path is not possible, unidirectional fields used.") + ) + if crawl_result.cross_db_remainder: + raise NotImplementedError( + "Cannot prefetch from other db yet. Maybe in future this feature will be added." ) queryset = related.queryset if related._is_finished: assert queryset is not None, "Queryset is not set but _is_finished flag" else: check_prefetch_collision(model, related) - crawl_result = crawl_relationship( - model.__class__, related.related_name, traverse_last=True - ) if queryset is None: - if crawl_result.reverse_path is False: - queryset = model.__class__.query.all() - else: - queryset = crawl_result.model_class.query.all() + queryset = crawl_result.model_class.query.all() - if queryset.model_class == model.__class__: - # queryset is of this model - queryset = queryset.select_related(related.related_name) - queryset.embed_parent = (related.related_name, "") - elif crawl_result.reverse_path is False: - QuerySetError( - detail=( - f"Creating a reverse path is not possible, unidirectional fields used." - f"You may want to use as queryset a queryset of model class {model!r}." 
- ) - ) - else: - # queryset is of the target model - queryset = queryset.select_related(crawl_result.reverse_path) - setattr(model, related.to_attr, await queryset.filter(*clauses)) + queryset = queryset.select_related(cast(str, crawl_result.reverse_path)) + clause = { + f"{crawl_result.reverse_path}__{pkcol}": row._mapping[f"{row_prefix}{pkcol}"] + for pkcol in cls.pkcolumns + } + setattr(model, related.to_attr, await queryset.filter(clause)) @classmethod async def __handle_prefetch_related( cls, row: "Row", - table: "Table", model: "Model", + prefix: str, + tables_and_models: dict[str, tuple["Table", type["BaseModelType"]]], prefetch_related: Sequence["Prefetch"], ) -> None: """ @@ -331,6 +305,9 @@ async def __handle_prefetch_related( # Check for conflicting names # Check as early as possible check_prefetch_collision(model=model, related=related) - queries.append(cls.__set_prefetch(row=row, table=table, model=model, related=related)) + row_prefix = f"{tables_and_models[prefix].name}_" if prefix else "" + queries.append( + cls.__set_prefetch(row=row, row_prefix=row_prefix, model=model, related=related) + ) if queries: await asyncio.gather(*queries) diff --git a/edgy/core/db/querysets/base.py b/edgy/core/db/querysets/base.py index aaf4b8e5..df514290 100644 --- a/edgy/core/db/querysets/base.py +++ b/edgy/core/db/querysets/base.py @@ -1,6 +1,5 @@ import asyncio import warnings -from collections import defaultdict from collections.abc import AsyncIterator, Awaitable, Generator, Iterable, Sequence from collections.abc import Iterable as CollectionsIterable from functools import cached_property @@ -24,16 +23,16 @@ from edgy.core.db.models.model_reference import ModelRef from edgy.core.db.models.types import BaseModelType from edgy.core.db.models.utils import apply_instance_extras -from edgy.core.db.querysets.mixins import QuerySetPropsMixin, TenancyMixin -from edgy.core.db.querysets.prefetch import Prefetch, PrefetchMixin, check_prefetch_collision -from edgy.core.db.querysets.types import EdgyEmbedTarget, EdgyModel, QueryType from edgy.core.db.relationships.utils import crawl_relationship -from edgy.core.utils.db import check_db_connection +from edgy.core.utils.db import check_db_connection, hash_tablekey from edgy.core.utils.sync import run_sync from edgy.exceptions import MultipleObjectsReturned, ObjectNotFound, QuerySetError from edgy.types import Undefined from . 
import clauses as clauses_mod +from .mixins import QuerySetPropsMixin, TenancyMixin +from .prefetch import Prefetch, PrefetchMixin, check_prefetch_collision +from .types import EdgyEmbedTarget, EdgyModel, QuerySetType, tables_and_models_type if TYPE_CHECKING: # pragma: no cover from databasez.core.transaction import Transaction @@ -46,6 +45,14 @@ _empty_set = cast(Sequence[Any], frozenset()) +def get_table_key_or_name(table: Union[sqlalchemy.Table, sqlalchemy.Alias]) -> str: + try: + return table.key # type: ignore + except AttributeError: + # alias + return table.name + + def clean_query_kwargs( model_class: type[BaseModelType], kwargs: dict[str, Any], @@ -72,19 +79,11 @@ def clean_query_kwargs( return new_kwargs -async def _parse_clause_arg(arg: Any, instance: "BaseQuerySet") -> Any: - if callable(arg): - arg = arg(instance) - if isawaitable(arg): - arg = await arg - return arg - - class BaseQuerySet( TenancyMixin, QuerySetPropsMixin, PrefetchMixin, - QueryType, + QuerySetType, ): """Internal definitions for queryset.""" @@ -118,6 +117,7 @@ def __init__( super().__init__(model_class=model_class) self.filter_clauses: list[Any] = list(filter_clauses) self.or_clauses: list[Any] = [] + self._aliases: dict[str, sqlalchemy.Alias] = {} if limit_count is not None: warnings.warn( "`limit_count` is deprecated use `limit`", DeprecationWarning, stacklevel=2 @@ -168,10 +168,8 @@ def __init__( # this is not cleared, because the expression is immutable self._cached_select_related_expression: Optional[ tuple[ - str, - dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]], - dict[str, set[str]], Any, + dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]], ] ] = None # initialize @@ -209,7 +207,7 @@ def _clone(self) -> "QuerySet": table=getattr(self, "_table", None), exclude_secrets=self._exclude_secrets, ) - queryset.or_clauses = list(self.or_clauses) + queryset.or_clauses.extend(self.or_clauses) queryset._cached_select_related_expression = self._cached_select_related_expression return cast("QuerySet", queryset) @@ -238,30 +236,58 @@ def _build_group_by_expression(self, group_by: Any, expression: Any) -> Any: expression = expression.group_by(*(self._prepare_order_by(entry) for entry in group_by)) return expression - async def _resolve_clause_args(self, args: Any) -> Any: + async def _resolve_clause_args( + self, args: Any, tables_and_models: tables_and_models_type + ) -> Any: result: list[Any] = [] for arg in args: - result.append(_parse_clause_arg(arg, self)) + result.append(clauses_mod.parse_clause_arg(arg, self, tables_and_models)) if self.database.force_rollback: return [await el for el in result] else: return await asyncio.gather(*result) - async def build_where_clause(self, _: Any = None) -> Any: + async def build_where_clause( + self, _: Any = None, tables_and_models: Optional[tables_and_models_type] = None + ) -> Any: """Build a where clause from the filters which can be passed in a where function.""" + joins: Optional[Any] = None + if tables_and_models is None: + joins, tables_and_models = self._build_tables_join_from_relationship() # ignored args for passing build_where_clause in filter_clauses - where_clause: list[Any] = [] + where_clauses: list[Any] = [] if self.or_clauses: - or_clauses = await self._resolve_clause_args(self.or_clauses) - where_clause.append( + or_clauses = await self._resolve_clause_args(self.or_clauses, tables_and_models) + where_clauses.append( or_clauses[0] if len(or_clauses) == 1 else clauses_mod.or_(*or_clauses) ) if self.filter_clauses: # we AND by default - 
where_clause.extend(await self._resolve_clause_args(self.filter_clauses)) + where_clauses.extend( + await self._resolve_clause_args(self.filter_clauses, tables_and_models) + ) # for nicer unpacking - return clauses_mod.and_(*where_clause) + if joins is None or len(tables_and_models) == 1: + return clauses_mod.and_(*where_clauses) + expression = sqlalchemy.sql.select( + *( + getattr(tables_and_models[""][0].c, col) + for col in tables_and_models[""][1].pkcolumns + ), + ).set_label_style(sqlalchemy.LABEL_STYLE_NONE) + idtuple = sqlalchemy.tuple_( + *( + getattr(tables_and_models[""][0].c, col) + for col in tables_and_models[""][1].pkcolumns + ) + ) + expression = expression.select_from(joins) + return idtuple.in_( + expression.where( + *where_clauses, + ) + ) def _build_select_distinct(self, distinct_on: Optional[Sequence[str]], expression: Any) -> Any: """Filters selects only specific fields. Leave empty to use simple distinct""" @@ -275,38 +301,33 @@ def _build_select_distinct(self, distinct_on: Optional[Sequence[str]], expressio def _join_table_helper( cls, join_clause: Any, - current_transition: tuple[str, str], + current_transition: tuple[str, str, str], *, - transitions: dict[tuple[str, str], tuple[Any, set[tuple[str, str]]]], + transitions: dict[tuple[str, str, str], tuple[Any, Optional[tuple[str, str, str]], str]], tables_and_models: dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]], - transitions_is_full_outer: dict[tuple[str, str], bool], ) -> Any: if current_transition not in transitions: return join_clause transition_value = transitions.pop(current_transition) - for dep in transition_value[1]: + if transition_value[1] is not None: join_clause = cls._join_table_helper( join_clause, - dep, + transition_value[1], transitions=transitions, tables_and_models=tables_and_models, - transitions_is_full_outer=transitions_is_full_outer, ) return sqlalchemy.sql.join( join_clause, - tables_and_models[current_transition[1]][0], + tables_and_models[transition_value[2]][0], transition_value[0], isouter=True, - full=transitions_is_full_outer.get(current_transition, False), ) - def _build_tables_select_from_relationship( + def _build_tables_join_from_relationship( self, - ) -> tuple[ - str, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]], dict[str, set[str]], Any - ]: + ) -> tuple[Any, tables_and_models_type]: """ Builds the tables relationships and joins. When a table contains more than one foreign key pointing to the same @@ -315,7 +336,7 @@ def _build_tables_select_from_relationship( """ # How does this work? 
- # First we build a transitions table with dependencies in case multiple pathes to the same table exist + # First we build a transitions table with a dependency, so we find a path # Secondly we check if a select_related path is joining a table from the set in an opposite direction # If yes, we mark the transition for a full outer join (dangerous, there could be side-effects) # At last we iter through the transisitions and build their dependencies first @@ -326,13 +347,13 @@ def _build_tables_select_from_relationship( if self._cached_select_related_expression is None: maintable = self.table select_from = maintable - maintablekey = maintable.key - tables_and_models: dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]] = { - select_from.key: (select_from, self.model_class) + tables_and_models: tables_and_models_type = {"": (select_from, self.model_class)} + _select_tables_and_models: tables_and_models_type = { + "": (select_from, self.model_class) } - prefixes: dict[str, set[str]] = defaultdict(set) - transitions: dict[tuple[str, str], tuple[Any, set[tuple[str, str]]]] = {} - transitions_is_full_outer: dict[tuple[str, str], bool] = {} + transitions: dict[ + tuple[str, str, str], tuple[Any, Optional[tuple[str, str, str]], str] + ] = {} # Select related for select_path in self._select_related: @@ -341,6 +362,8 @@ def _build_tables_select_from_relationship( former_table = maintable former_transition = None prefix: str = "" + _select_prefix: str = "" + injected_prefix: Union[bool, str] = False model_database: Optional[Database] = self.database while select_path: field_name = select_path.split("__", 1)[0] @@ -369,15 +392,20 @@ def _build_tables_select_from_relationship( ) # now use the one of the model_class itself model_database = None - table = model_class.table_schema(self.active_schema) - # use table from tables_and_models - if table.key in tables_and_models: - table = tables_and_models[table.key][0] - + if injected_prefix: + injected_prefix = False + else: + prefix = f"{prefix}__{field_name}" if prefix else f"{field_name}" + _select_prefix = ( + f"{_select_prefix}__{field_name}" if _select_prefix else f"{field_name}" + ) if foreign_key.is_m2m and foreign_key.embed_through != "": # type: ignore # we need to inject the through model for the select model_class = foreign_key.through - table = model_class.table_schema(self.active_schema) + if foreign_key.embed_through is False: + injected_prefix = True + else: + injected_prefix = f"{prefix}__{foreign_key.embed_through}" if reverse: select_path = f"{foreign_key.from_foreign_key}__{select_path}" else: @@ -389,9 +417,14 @@ def _build_tables_select_from_relationship( else: foreign_key = model_class.meta.fields[foreign_key.from_foreign_key] reverse = True - prefix = f"{prefix}__{field_name}" if prefix else f"{prefix}" - prefixes[table.key].add(prefix) - transition_key = (former_table.key, table.key) + if _select_prefix in _select_tables_and_models: + # use prexisting prefix + table: Any = _select_tables_and_models[_select_prefix][0] + else: + table = model_class.table_schema(self.active_schema) + table = table.alias(hash_tablekey(tablekey=table.key, prefix=prefix)) + + transition_key = (get_table_key_or_name(former_table), table.name, field_name) if transition_key in transitions: # can not provide new informations former_table = table @@ -402,37 +435,19 @@ def _build_tables_select_from_relationship( foreign_key, table, reverse, former_table ) ) - if (table.key, former_table.key) in transitions: - _transition_key = (table.key, former_table.key) - # 
inverted - # only make full outer when not the main query - if former_table.key != maintablekey: - transitions_is_full_outer[_transition_key] = True - transitions[_transition_key] = ( - clauses_mod.or_(transitions[_transition_key][0], and_clause), - {*transitions[_transition_key][1], former_transition} - if former_transition - else transitions[_transition_key][1], - ) - elif table.key in tables_and_models: - for _transition_key in transitions: - if _transition_key[1] == table.key: - break - else: - # this should never happen - raise Exception("transition not found despite in tables_and_models") - transitions[_transition_key] = ( - clauses_mod.or_(and_clause, transitions[_transition_key][0]), - {*transitions[_transition_key][1], former_transition} - if former_transition - else transitions[_transition_key][0], - ) - else: - transitions[(former_table.key, table.key)] = ( - and_clause, - {former_transition} if former_transition else set(), - ) - tables_and_models[table.key] = table, model_class + transitions[transition_key] = ( + and_clause, + former_transition, + _select_prefix, + ) + if injected_prefix is False: + tables_and_models[prefix] = table, model_class + elif injected_prefix is not True: + # we inject a string + tables_and_models[injected_prefix] = table, model_class + + # prefix used for select_related + _select_tables_and_models[_select_prefix] = table, model_class former_table = table former_transition = transition_key @@ -441,14 +456,11 @@ def _build_tables_select_from_relationship( select_from, next(iter(transitions.keys())), transitions=transitions, - tables_and_models=tables_and_models, - transitions_is_full_outer=transitions_is_full_outer, + tables_and_models=_select_tables_and_models, ) self._cached_select_related_expression = ( - maintablekey, - tables_and_models, - prefixes, select_from, + tables_and_models, ) return self._cached_select_related_expression @@ -476,94 +488,78 @@ def _validate_only_and_defer(self) -> None: async def _as_select_with_tables( self, - ) -> tuple[Any, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]]: + ) -> tuple[Any, tables_and_models_type]: """ Builds the query select based on the given parameters and filters. """ - queryset: BaseQuerySet = self - - queryset._validate_only_and_defer() - maintable, tables_and_models, prefixes, select_from = ( - queryset._build_tables_select_from_relationship() - ) - columns = [] - for tablekey, (table, model_class) in tables_and_models.items(): - if tablekey == maintable: + self._validate_only_and_defer() + joins, tables_and_models = self._build_tables_join_from_relationship() + columns: list[Any] = [] + for prefix, (table, model_class) in tables_and_models.items(): + if not prefix: for column_key, column in table.columns.items(): # e.g. reflection has not always a field field_name = model_class.meta.columns_to_field.get(column_key, column_key) - if queryset._only and field_name not in queryset._only: + if self._only and field_name not in self._only: continue - if queryset._defer and field_name in queryset._defer: + if self._defer and field_name in self._defer: continue if ( - queryset._exclude_secrets + self._exclude_secrets and field_name in model_class.meta.fields and model_class.meta.fields[field_name].secret ): continue - - # columns.append(column.label(f"{table.key.replace(".", "_")}_{column.key}")) + # add without alias columns.append(column) else: - prefixes_for_table = prefixes[table.key] for column_key, column in table.columns.items(): # e.g. 
reflection has not always a field field_name = model_class.meta.columns_to_field.get(column_key, column_key) - if queryset._only and all( - f"{prefix}" not in queryset._only - and f"{prefix}__{field_name}" not in queryset._only - for prefix in prefixes_for_table + if ( + self._only + and prefix not in self._only + and f"{prefix}__{field_name}" not in self._only ): continue - if queryset._defer and any( - f"{prefix}" in queryset._defer - or f"{prefix}__{field_name}" in queryset._defer - for prefix in prefixes_for_table + if self._defer and ( + prefix in self._defer or f"{prefix}__{field_name}" in self._defer ): continue if ( - queryset._exclude_secrets + self._exclude_secrets and field_name in model_class.meta.fields and model_class.meta.fields[field_name].secret ): continue - # columns.append(column.label(f"{table.key.replace(".", "_")}_{column.key}")) - columns.append(column) + # alias has name not a key. The name is fully descriptive + columns.append(column.label(f"{table.name}_{column_key}")) assert columns, "no columns specified" - # all columns are aliased now - expression = sqlalchemy.sql.select(*columns).set_label_style( - sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL - ) - expression = expression.select_from(select_from) - expression = expression.where(await queryset.build_where_clause()) + # all columns are aliased already + expression = sqlalchemy.sql.select(*columns).set_label_style(sqlalchemy.LABEL_STYLE_NONE) + expression = expression.select_from(joins) + expression = expression.where(await self.build_where_clause(self, tables_and_models)) - if queryset._order_by: - expression = queryset._build_order_by_expression( - queryset._order_by, expression=expression - ) + if self._order_by: + expression = self._build_order_by_expression(self._order_by, expression=expression) - if queryset.limit_count: - expression = expression.limit(queryset.limit_count) + if self.limit_count: + expression = expression.limit(self.limit_count) - if queryset._offset: - expression = expression.offset(queryset._offset) + if self._offset: + expression = expression.offset(self._offset) - if queryset._group_by: - expression = queryset._build_group_by_expression( - queryset._group_by, expression=expression - ) + if self._group_by: + expression = self._build_group_by_expression(self._group_by, expression=expression) - if queryset.distinct_on is not None: - expression = queryset._build_select_distinct( - queryset.distinct_on, expression=expression - ) + if self.distinct_on is not None: + expression = self._build_select_distinct(self.distinct_on, expression=expression) return expression, tables_and_models async def as_select_with_tables( self, - ) -> tuple[Any, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]]: + ) -> tuple[Any, tables_and_models_type]: """ Builds the query select based on the given parameters and filters. 
""" @@ -611,48 +607,41 @@ def _kwargs_to_clauses( # bind local vars async def wrapper( queryset: "QuerySet", + tables_and_models: tables_and_models_type, + *, _field: "BaseFieldType" = field, _sub_query: "QuerySet" = sub_query, + _prefix: str = related_str, ) -> Any: + table = tables_and_models[_prefix][0] fk_tuple = sqlalchemy.tuple_( - *( - getattr(queryset.table.columns, colname) - for colname in _field.get_column_names() - ) + *(getattr(table.columns, colname) for colname in _field.get_column_names()) ) return fk_tuple.in_(await _sub_query) clauses.append(wrapper) - elif callable(value): - # bind local vars + else: + assert not isinstance( + value, BaseModelType + ), f"should be parsed in clean: {key}: {value}" + async def wrapper( queryset: "QuerySet", + tables_and_models: tables_and_models_type, + *, _field: "BaseFieldType" = field, _value: Any = value, _op: Optional[str] = op, + _prefix: str = related_str, ) -> Any: - _value = _value(queryset) - if isawaitable(_value): - _value = await _value - return _field.operator_to_clause( - _field.name, - _op, - queryset.model_class.table_schema(queryset.active_schema), - _value, + _value = await clauses_mod.parse_clause_arg( + _value, queryset, tables_and_models ) + table = tables_and_models[_prefix][0] + return _field.operator_to_clause(_field.name, _op, table, _value) clauses.append(wrapper) - else: - assert not isinstance( - value, BaseModelType - ), f"should be parsed in clean: {key}: {value}" - clauses.append( - field.operator_to_clause( - field_name, op, model_class.table_schema(self.active_schema), value - ) - ) - return clauses, select_related def _prepare_order_by(self, order_by: str) -> Any: @@ -734,16 +723,6 @@ async def _handle_batch( del queryset _prefetch_related: list[Prefetch] = [] - clauses = [] - for pkcol in self.model_class.pkcolumns: - clauses.append( - getattr(self.table.columns, pkcol).in_( - [ - row._mapping[f"{self.table.key.replace('.', '_', 1)}_{pkcol}"] - for row in batch - ] - ) - ) for prefetch in self._prefetch_related: check_prefetch_collision(self.model_class, prefetch) # type: ignore @@ -754,34 +733,40 @@ async def _handle_batch( raise NotImplementedError( "Cannot prefetch from other db yet. Maybe in future this feature will be added." ) + if crawl_result.reverse_path is False: + QuerySetError( + detail=("Creating a reverse path is not possible, unidirectional fields used.") + ) prefetch_queryset: Optional[QuerySet] = prefetch.queryset + + clauses = [ + { + f"{crawl_result.reverse_path}__{pkcol}": row._mapping[pkcol] + for pkcol in self.model_class.pkcolumns + } + for row in batch + ] if prefetch_queryset is None: - if crawl_result.reverse_path is False: - prefetch_queryset = self.model_class.query.filter(*clauses) - else: - prefetch_queryset = crawl_result.model_class.query.filter(*clauses) + prefetch_queryset = crawl_result.model_class.query.local_or(*clauses) else: - prefetch_queryset = prefetch_queryset.filter(*clauses) + # ensure local or + prefetch_queryset = prefetch_queryset.local_or(*clauses) - if prefetch_queryset.model_class == self.model_class: + if prefetch_queryset.model_class is self.model_class: # queryset is of this model prefetch_queryset = prefetch_queryset.select_related(prefetch.related_name) prefetch_queryset.embed_parent = (prefetch.related_name, "") - elif crawl_result.reverse_path is False: - QuerySetError( - detail=( - f"Creating a reverse path is not possible, unidirectional fields used." - f"You may want to use as queryset a queryset of model class {self.model_class!r}." 
- ) - ) else: # queryset is of the target model - prefetch_queryset = prefetch_queryset.select_related(crawl_result.reverse_path) + prefetch_queryset = prefetch_queryset.select_related( + cast(str, crawl_result.reverse_path) + ) new_prefetch = Prefetch( related_name=prefetch.related_name, to_attr=prefetch.to_attr, queryset=prefetch_queryset, ) + new_prefetch._bake_prefix = f"{hash_tablekey(tablekey=tables_and_models[''][0].key, prefix=cast(str, crawl_result.reverse_path))}_" new_prefetch._is_finished = True _prefetch_related.append(new_prefetch) @@ -887,7 +872,7 @@ def _filter_or_exclude( Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -911,7 +896,7 @@ def _filter_or_exclude( Union[ sqlalchemy.sql.expression.BinaryExpression, Callable[ - [QueryType], + [QuerySetType], Union[ sqlalchemy.sql.expression.BinaryExpression, Awaitable[sqlalchemy.sql.expression.BinaryExpression], @@ -929,11 +914,12 @@ def _filter_or_exclude( async def wrapper_and( queryset: "QuerySet", + tables_and_models: tables_and_models_type, _extracted_clauses: Sequence[ Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -943,7 +929,11 @@ async def wrapper_and( ] = extracted_clauses, ) -> Any: return clauses_mod.and_( - *(await self._resolve_clause_args(_extracted_clauses)) + *( + await self._resolve_clause_args( + _extracted_clauses, tables_and_models + ) + ) ) if allow_global_or and len(clauses) == 1: @@ -971,14 +961,22 @@ async def wrapper_and( if exclude: op = clauses_mod.and_ if not or_ else clauses_mod.or_ - async def wrapper(queryset: "QuerySet") -> Any: - return clauses_mod.not_(op(*(await self._resolve_clause_args(converted_clauses)))) + async def wrapper( + queryset: "QuerySet", tables_and_models: tables_and_models_type + ) -> Any: + return clauses_mod.not_( + op(*(await self._resolve_clause_args(converted_clauses, tables_and_models))) + ) queryset.filter_clauses.append(wrapper) elif or_: - async def wrapper(queryset: "QuerySet") -> Any: - return clauses_mod.or_(*(await self._resolve_clause_args(converted_clauses))) + async def wrapper( + queryset: "QuerySet", tables_and_models: tables_and_models_type + ) -> Any: + return clauses_mod.or_( + *(await self._resolve_clause_args(converted_clauses, tables_and_models)) + ) queryset.filter_clauses.append(wrapper) else: @@ -1034,6 +1032,9 @@ async def _get_raw(self, **kwargs: Any) -> tuple[BaseModelType, Any]: return await self._get_or_cache_row(rows[0], tables_and_models, "_cache_first,_cache_last") + def __repr__(self) -> str: + return f"QuerySet<{self.sql}>" + class QuerySet(BaseQuerySet): """ @@ -1043,14 +1044,18 @@ class QuerySet(BaseQuerySet): @cached_property def sql(self) -> str: """Get SQL select query as string.""" - return str(run_sync(self.as_select())) + return str( + run_sync(self.as_select()).compile( + compile_kwargs={"literal_binds": True}, + ) + ) def filter( self, *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -1080,7 +1085,7 @@ def or_( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ 
"sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -1096,12 +1101,35 @@ def or_( """ return self._filter_or_exclude(clauses=clauses, or_=True, kwargs=kwargs) + def local_or( + self, + *clauses: Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Callable[ + ["QuerySetType"], + Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Awaitable["sqlalchemy.sql.expression.BinaryExpression"], + ], + ], + dict[str, Any], + "QuerySet", + ], + **kwargs: Any, + ) -> "QuerySet": + """ + Filters the QuerySet by the OR operand. + """ + return self._filter_or_exclude( + clauses=clauses, or_=True, kwargs=kwargs, allow_global_or=False + ) + def and_( self, *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -1121,7 +1149,7 @@ def not_( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], @@ -1142,7 +1170,7 @@ def exclude( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], diff --git a/edgy/core/db/querysets/clauses.py b/edgy/core/db/querysets/clauses.py index 9e59eff8..aeba7066 100644 --- a/edgy/core/db/querysets/clauses.py +++ b/edgy/core/db/querysets/clauses.py @@ -1,5 +1,6 @@ from __future__ import annotations +from inspect import isawaitable from typing import TYPE_CHECKING, Any, Union import sqlalchemy @@ -8,6 +9,18 @@ if TYPE_CHECKING: from edgy.core.db.models import Model + from .types import QuerySetType, tables_and_models_type + + +async def parse_clause_arg( + arg: Any, instance: QuerySetType, tables_and_models: tables_and_models_type +) -> Any: + if callable(arg): + arg = arg(instance, tables_and_models) + if isawaitable(arg): + arg = await arg + return arg + class _EnhancedClausesHelper: def __init__(self, op: Any, default_empty: Any) -> None: @@ -20,12 +33,13 @@ def __call__(self, *args: Any) -> Any: return self.op(*args) def from_kwargs( - self, columns_or_model: Union[Model, ColumnCollection], /, **kwargs: Any + self, columns_or_model: Union[Model, ColumnCollection, sqlalchemy.Table], /, **kwargs: Any ) -> Any: + # inferior to the kwargs parser of QuerySet if not isinstance(columns_or_model, ColumnCollection) and hasattr( columns_or_model, "columns" ): - columns_or_model = columns_or_model.table.columns + columns_or_model = columns_or_model.columns return self.op(*(getattr(columns_or_model, item[0]) == item[1] for item in kwargs.items())) diff --git a/edgy/core/db/querysets/mixins.py b/edgy/core/db/querysets/mixins.py index 40ed2a4a..0cfc2fad 100644 --- a/edgy/core/db/querysets/mixins.py +++ b/edgy/core/db/querysets/mixins.py @@ -99,8 +99,11 @@ def using( queryset.database = connection if schema is not Undefined: queryset.using_schema = schema if schema is not False else Undefined - queryset.active_schema = queryset.get_schema() - queryset.table = None + new_schema = queryset.get_schema() + if new_schema != queryset.active_schema: + queryset.active_schema = new_schema + queryset.table = None + return queryset def using_with_db( diff --git a/edgy/core/db/querysets/prefetch.py b/edgy/core/db/querysets/prefetch.py index 
70cadc7f..daa99243 100644 --- a/edgy/core/db/querysets/prefetch.py +++ b/edgy/core/db/querysets/prefetch.py @@ -24,6 +24,7 @@ def __init__( self.to_attr = to_attr self.queryset: Optional[QuerySet] = queryset self._is_finished = False + self._bake_prefix: str = "" self._baked_results: dict[tuple[str, ...], list[Any]] = defaultdict(list) self._baked = False @@ -35,7 +36,7 @@ async def init_bake(self, model_class: type["Model"]) -> None: async for result in self.queryset._execute_iterate(True): # a bit hacky but we need the current row model_key = model_class.create_model_key_from_sqla_row( - self.queryset._cache_current_row + self.queryset._cache_current_row, row_prefix=self._bake_prefix ) self._baked_results[model_key].append(result) diff --git a/edgy/core/db/querysets/types.py b/edgy/core/db/querysets/types.py index dfd90be1..f71d8057 100644 --- a/edgy/core/db/querysets/types.py +++ b/edgy/core/db/querysets/types.py @@ -23,8 +23,10 @@ EdgyModel = TypeVar("EdgyModel", bound="BaseModelType") EdgyEmbedTarget = TypeVar("EdgyEmbedTarget") +tables_and_models_type = dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]] -class QueryType(ABC, Generic[EdgyEmbedTarget, EdgyModel]): + +class QuerySetType(ABC, Generic[EdgyEmbedTarget, EdgyModel]): __slots__ = ("model_class",) model_class: type[EdgyModel] @@ -40,20 +42,20 @@ def filter( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], dict[str, Any], - "QueryType", + "QuerySetType", ], **kwargs: Any, - ) -> "QueryType": ... + ) -> "QuerySetType": ... @abstractmethod - def all(self, clear_cache: bool = False) -> "QueryType": ... + def all(self, clear_cache: bool = False) -> "QuerySetType": ... @abstractmethod def or_( @@ -61,16 +63,36 @@ def or_( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], - "QueryType", + "QuerySetType", ], **kwargs: Any, - ) -> "QueryType": + ) -> "QuerySetType": + """ + Filters the QuerySet by the OR operand. + """ + + @abstractmethod + def local_or( + self, + *clauses: Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Callable[ + ["QuerySetType"], + Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Awaitable["sqlalchemy.sql.expression.BinaryExpression"], + ], + ], + "QuerySetType", + ], + **kwargs: Any, + ) -> "QuerySetType": """ Filters the QuerySet by the OR operand. """ @@ -81,17 +103,17 @@ def and_( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], dict[str, Any], - "QueryType", + "QuerySetType", ], **kwargs: Any, - ) -> "QueryType": + ) -> "QuerySetType": """ Filters the QuerySet by the AND operand. Alias of filter. """ @@ -102,17 +124,17 @@ def not_( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], dict[str, Any], - "QueryType", + "QuerySetType", ], **kwargs: Any, - ) -> "QueryType": + ) -> "QuerySetType": """ Filters the QuerySet by the NOT operand. Alias of exclude. 
""" @@ -124,47 +146,47 @@ def exclude( *clauses: Union[ "sqlalchemy.sql.expression.BinaryExpression", Callable[ - ["QueryType"], + ["QuerySetType"], Union[ "sqlalchemy.sql.expression.BinaryExpression", Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], dict[str, Any], - "QueryType", + "QuerySetType", ], **kwargs: Any, - ) -> "QueryType": ... + ) -> "QuerySetType": ... @abstractmethod - def lookup(self, term: Any) -> "QueryType": ... + def lookup(self, term: Any) -> "QuerySetType": ... @abstractmethod - def order_by(self, *columns: str) -> "QueryType": ... + def order_by(self, *columns: str) -> "QuerySetType": ... @abstractmethod - def reverse(self) -> "QueryType": ... + def reverse(self) -> "QuerySetType": ... @abstractmethod - def limit(self, limit_count: int) -> "QueryType": ... + def limit(self, limit_count: int) -> "QuerySetType": ... @abstractmethod - def offset(self, offset: int) -> "QueryType": ... + def offset(self, offset: int) -> "QuerySetType": ... @abstractmethod - def group_by(self, *group_by: str) -> "QueryType": ... + def group_by(self, *group_by: str) -> "QuerySetType": ... @abstractmethod - def distinct(self, *distinct_on: Sequence[str]) -> "QueryType": ... + def distinct(self, *distinct_on: Sequence[str]) -> "QuerySetType": ... @abstractmethod - def select_related(self, *related: str) -> "QueryType": ... + def select_related(self, *related: str) -> "QuerySetType": ... @abstractmethod - def only(self, *fields: str) -> "QueryType": ... + def only(self, *fields: str) -> "QuerySetType": ... @abstractmethod - def defer(self, *fields: str) -> "QueryType": ... + def defer(self, *fields: str) -> "QuerySetType": ... @abstractmethod async def exists(self) -> bool: ... @@ -244,10 +266,13 @@ def using( *, database: Union[str, Any, None, "Database"] = Undefined, schema: Union[str, Any, None, Literal[False]] = Undefined, - ) -> "QueryType": ... + ) -> "QuerySetType": ... @abstractmethod def __await__(self) -> Generator[Any, None, list[EdgyEmbedTarget]]: ... @abstractmethod async def __aiter__(self) -> AsyncIterator[EdgyEmbedTarget]: ... 
+ + +QueryType = QuerySetType diff --git a/edgy/core/utils/db.py b/edgy/core/utils/db.py index 4ccf7478..3dd180c5 100644 --- a/edgy/core/utils/db.py +++ b/edgy/core/utils/db.py @@ -1,4 +1,7 @@ +import base64 +import hashlib import warnings +from functools import lru_cache from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -16,3 +19,20 @@ def check_db_connection(db: "Database") -> None: UserWarning, stacklevel=2, ) + + +@lru_cache(512, typed=False) +def _hash_tablekey(tablekey: str, prefix: str) -> str: + tablehash = ( + base64.urlsafe_b64encode(hashlib.new("md5", f"{tablekey}_{prefix}".encode()).digest()) + .decode() + .rstrip("=") + ) + + return f"_join_{tablehash}" + + +def hash_tablekey(*, tablekey: str, prefix: str) -> str: + if not prefix: + return tablekey + return _hash_tablekey(tablekey, prefix) diff --git a/tests/exclude_secrets/test_exclude_nested.py b/tests/exclude_secrets/test_exclude_nested.py index 164e840f..c3e28acb 100644 --- a/tests/exclude_secrets/test_exclude_nested.py +++ b/tests/exclude_secrets/test_exclude_nested.py @@ -3,6 +3,7 @@ import pytest import edgy +from edgy.core.utils.db import hash_tablekey from edgy.testclient import DatabaseTestClient from tests.settings import DATABASE_URL @@ -54,8 +55,8 @@ async def test_exclude_secrets_excludes_top_name_equals_to_name_in_foreignkey_no Organisation.query.select_related("user__profile").exclude_secrets().order_by("id") ).as_select() org_query_text = str(org_query) - assert "profiles.name" in org_query_text - assert 'users".name' not in org_query_text + assert f"{hash_tablekey(tablekey='profiles', prefix='user__profile')}_name" in org_query_text + assert f"{hash_tablekey(tablekey='users', prefix='user')}_name" not in org_query_text async def test_exclude_secrets_excludes_top_name_equals_to_name_in_foreignkey_not_secret(): diff --git a/tests/models/run_sync/test_models_filter.py b/tests/models/run_sync/test_models_filter.py index d7c9cf78..a7a5e9aa 100644 --- a/tests/models/run_sync/test_models_filter.py +++ b/tests/models/run_sync/test_models_filter.py @@ -107,10 +107,10 @@ async def test_model_filter(): products = Product.query.exclude(name__icontains="%") assert edgy.run_sync(products.count()) == 3 # test lambda filters - products = Product.query.exclude(name__contains=lambda x: "%") + products = Product.query.exclude(name__contains=lambda x, y: "%") assert edgy.run_sync(products.count()) == 3 - async def custom_filter(x): + async def custom_filter(x, y): return "%" products = Product.query.exclude(name__contains=custom_filter) diff --git a/tests/models/test_model_queryset_update.py b/tests/models/test_model_queryset_update.py index 033480ab..fab43249 100644 --- a/tests/models/test_model_queryset_update.py +++ b/tests/models/test_model_queryset_update.py @@ -27,7 +27,14 @@ class Product(edgy.Model): class Meta: registry = models - name = "products" + + +class ProductTag(edgy.Model): + product = edgy.fields.ForeignKey(Product, related_name="tags", null=True) + tag = edgy.fields.CharField(max_length=30) + + class Meta: + registry = models @pytest.fixture(autouse=True, scope="module") @@ -59,6 +66,18 @@ async def test_queryset_update(): assert tie.rating == 3 +async def test_queryset_update_via_related(): + shirt = await Product.query.create(name="Shirt", rating=5, tags=[ProductTag(tag="foo")]) + tie = await Product.query.create(name="Tie", rating=4, tags=[ProductTag(tag="faa")]) + assert tie.rating == 4 + assert shirt.rating == 5 + await Product.query.filter(tags__tag="foo").update(rating=1) + await tie.load() + await 
shirt.load() + assert shirt.rating == 1 + assert tie.rating == 4 + + async def test_model_update_or_create(): user, created = await User.query.update_or_create( name="Test", language="English", defaults={"name": "Jane"} diff --git a/tests/models/test_models_filter.py b/tests/models/test_models_filter.py index d2300f33..5c43b6ef 100644 --- a/tests/models/test_models_filter.py +++ b/tests/models/test_models_filter.py @@ -107,10 +107,10 @@ async def test_model_filter(): products = Product.query.exclude(name__icontains="%") assert await products.count() == 3 # test lambda filters - products = Product.query.exclude(name__contains=lambda x: "%") + products = Product.query.exclude(name__contains=lambda x, y: "%") assert await products.count() == 3 - async def custom_filter(x): + async def custom_filter(x, y): return "%" products = Product.query.exclude(name__contains=custom_filter)
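Finally, the adjusted tests above hint at the new call conventions; below is a minimal sketch, assuming a hypothetical `Product` model on its own registry (names and fields are illustrative only, not part of the diff).

```python
import edgy

# Hypothetical registry and model, loosely mirroring the test fixtures above.
models = edgy.Registry(database=edgy.Database("sqlite:///example.sqlite"))


class Product(edgy.Model):
    name = edgy.fields.CharField(max_length=100)

    class Meta:
        registry = models


# Value callables passed to filter()/exclude() now take two positional
# parameters: the queryset and the tables_and_models mapping.
async def pattern(queryset, tables_and_models):
    return "%"


escaped = Product.query.exclude(name__contains=pattern)

# local_or behaves like or_(), but never switches into the global OR mode;
# each dict forms one OR branch.
shirts_or_ties = Product.query.local_or(
    {"name__icontains": "shirt"}, {"name__icontains": "tie"}
)

# Inside an async context these would be executed with, e.g.:
#     count = await escaped.count()
#     results = await shirts_or_ties
```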