diff --git a/docs/release-notes.md b/docs/release-notes.md index 9675ac0d..0b43e26e 100644 --- a/docs/release-notes.md +++ b/docs/release-notes.md @@ -11,11 +11,17 @@ hide: - ComputedField. - Permission template. +- `reverse_clean` for ForeignKeys. +- Expanded filter methods of querysets (can pass now dict and querysets). -### Fixed +### Changed - Managers use now instance attributes (database, schema). +### Fixed + +- `select_related` works across ManyToMany fields. + ## 0.17.4 ### Fixed diff --git a/edgy/contrib/permissions/managers.py b/edgy/contrib/permissions/managers.py index 74d8a8bb..aa34b00c 100644 --- a/edgy/contrib/permissions/managers.py +++ b/edgy/contrib/permissions/managers.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Literal, cast from edgy.core.db.models.managers import Manager from edgy.core.db.models.types import BaseModelType @@ -17,19 +17,33 @@ def permissions_of(self, sources: Sequence[BaseModelType] | BaseModelType) -> Qu sources = [sources] if len(sources) == 0: # none - return cast(QuerySet, self.filter(and_())) + return cast("QuerySet", self.filter(and_())) UserField = self.owner.meta.fields["users"] + assert ( + UserField.embed_through is False or UserField.embed_through + ), "users field need embed_through=foo|False." GroupField = self.owner.meta.fields.get("groups", None) - query = cast(QuerySet, self.all()) + assert ( + GroupField is None or GroupField.embed_through is False or GroupField.embed_through + ), "groups field need embed_through=foo|False." + query = cast("QuerySet", self.all()) + + groups_field_user: Literal[False] | str = False + if GroupField is not None: + groups_field_user = GroupField.target.meta.fields[ + self.owner.users_field_group + ].reverse_name + assert isinstance( + groups_field_user, str + ), f"{GroupField.target} {self.owner.users_field_group} field needs reverse_name." for source in sources: if isinstance(source, UserField.target): - clause: dict[str, Any] = {"users": source} - if GroupField is not None: - clause[f"groups__{self.owner.users_field_group}"] = source + clause: dict[str, Any] = {"users__pk": source} + if groups_field_user: + clause[f"{groups_field_user}__{self.owner.users_field_group}__pk"] = source query = query.or_(**clause) - return query elif GroupField is not None and isinstance(source, GroupField.target): - query = query.or_(groups=source) + query = query.or_(groups__pk=source) else: raise ValueError(f"Invalid source: {source}.") return query @@ -37,66 +51,126 @@ def permissions_of(self, sources: Sequence[BaseModelType] | BaseModelType) -> Qu def users( self, permissions: Sequence[str] | str, - model_names: Sequence[str] | str | None = None, - objects: Sequence[BaseModelType] | BaseModelType | None = None, + model_names: Sequence[str | None] | str | None = None, + objects: Sequence[BaseModelType | None] | BaseModelType | None = None, + include_null_model_name: bool = True, + include_null_object: bool = True, ) -> QuerySet: if isinstance(permissions, str): permissions = [permissions] if isinstance(model_names, str): model_names = [model_names] + if model_names is not None and include_null_model_name: + model_names = [*model_names, None] if isinstance(objects, BaseModelType): objects = [objects] + if objects is not None and include_null_object: + objects = [*objects, None] UserField = self.owner.meta.fields["users"] + assert ( + UserField.embed_through is False or UserField.embed_through + ), "users field need embed_through=foo|False." GroupField = self.owner.meta.fields.get("groups", None) + assert ( + GroupField is None or GroupField.embed_through is False or GroupField.embed_through + ), "groups field need embed_through=foo|False." + assert ( + GroupField is None + or GroupField.target.meta.fields[self.owner.users_field_group].embed_through is False + or GroupField.target.meta.fields[self.owner.users_field_group].embed_through + ), f"{GroupField.target} {self.owner.users_field_group} field need embed_through=foo|False." ModelNameField = self.owner.meta.fields.get("model_name", None) + assert ( + ModelNameField is None + or ModelNameField.embed_through is False + or ModelNameField.embed_through + ), "model_name field need embed_through=foo|False." ContentTypeField = self.owner.meta.fields.get("obj", None) + assert ( + ContentTypeField is None + or ContentTypeField.embed_through is False + or ContentTypeField.embed_through + ), "obj field need embed_through=foo|False." if objects is not None and len(objects) == 0: # none return cast("QuerySet", UserField.target.query.filter(and_())) - clauses: dict[str, Any] = {f"{UserField.reverse_name}__name__in": permissions} + clauses: list[dict[str, Any]] = [{f"{UserField.reverse_name}__name__in": permissions}] if model_names is not None: if ModelNameField is not None: - clauses[f"{UserField.reverse_name}__model_name__in"] = model_names + clauses[-1][f"{UserField.reverse_name}__model_name__in"] = model_names elif ContentTypeField is not None: - clauses[f"{UserField.reverse_name}__obj__name__in"] = model_names + clauses[-1][f"{UserField.reverse_name}__obj__name__in"] = model_names + if GroupField is not None: - clauses[f"{self.owner.groups_field_user}__{GroupField.reverse_name}__name__in"] = ( - permissions - ) + clauses.append({}) + groups_field_user = GroupField.target.meta.fields[ + self.owner.users_field_group + ].reverse_name + assert isinstance( + groups_field_user, str + ), f"{GroupField.target} {self.owner.users_field_group} field needs reverse_name." + clauses[-1][f"{groups_field_user}__{GroupField.reverse_name}__name__in"] = permissions if model_names is not None: if ModelNameField is not None: - clauses[ - f"{self.owner.groups_field_user}__{GroupField.reverse_name}__model_name__in" + clauses[-1][ + f"{groups_field_user}__{GroupField.reverse_name}__model_name__in" ] = model_names elif ContentTypeField is not None: - clauses[ - f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj__name__in" + clauses[-1][ + f"{groups_field_user}__{GroupField.reverse_name}__obj__name__in" ] = model_names - query = cast("QuerySet", UserField.target.query.filter(**clauses)) + query = cast("QuerySet", UserField.target.query.or_(*clauses)) + if objects is not None: + obj_clauses = [] for obj in objects: - clause = {f"{UserField.reverse_name}__obj": obj} + obj_clauses.append({f"{UserField.reverse_name}__obj": obj}) if GroupField is not None: - clause[f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj"] = obj - query = query.or_(**clause) + obj_clauses[-1][ + f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj" + ] = obj + query = query.or_(*obj_clauses) return query def groups( self, permissions: Sequence[str] | str, - model_names: Sequence[str] | str | None = None, - objects: Sequence[BaseModelType] | BaseModelType | None = None, + model_names: Sequence[str | None] | str | None = None, + objects: Sequence[BaseModelType | None] | BaseModelType | None = None, + include_null_model_name: bool = True, + include_null_object: bool = True, ) -> QuerySet: if isinstance(permissions, str): permissions = [permissions] if isinstance(model_names, str): model_names = [model_names] + if model_names is not None and include_null_model_name: + model_names = [*model_names, None] if isinstance(objects, BaseModelType): objects = [objects] + if objects is not None and include_null_object: + objects = [*objects, None] GroupField = self.owner.meta.fields["groups"] + assert ( + GroupField.embed_through is False or GroupField.embed_through + ), "groups field need embed_through=foo|False." + assert ( + GroupField.target.meta.fields[self.owner.users_field_group].embed_through is False + or GroupField.target.meta.fields[self.owner.users_field_group].embed_through + ), f"{GroupField.target} {self.owner.users_field_group} field need embed_through=foo|False." ModelNameField = self.owner.meta.fields.get("model_name", None) + assert ( + ModelNameField is None + or ModelNameField.embed_through is False + or ModelNameField.embed_through + ), "model_name field need embed_through=foo|False." ContentTypeField = self.owner.meta.fields.get("obj", None) + assert ( + ContentTypeField is None + or ContentTypeField.embed_through is False + or ContentTypeField.embed_through + ), "obj field need embed_through=foo|False." if objects is not None and len(objects) == 0: # none return cast("QuerySet", GroupField.target.query.filter(and_())) diff --git a/edgy/contrib/permissions/models.py b/edgy/contrib/permissions/models.py index ee610aeb..52e1d90c 100644 --- a/edgy/contrib/permissions/models.py +++ b/edgy/contrib/permissions/models.py @@ -6,7 +6,6 @@ class BasePermission(edgy.Model): - groups_field_user: ClassVar[str] = "groups" users_field_group: ClassVar[str] = "users" name: str = edgy.fields.CharField(max_length=100, null=False) # model_name: str = edgy.fields.CharField(max_length=100, null=True) @@ -18,8 +17,9 @@ class BasePermission(edgy.Model): fallback_getter=lambda field, instance, owner: instance.name, ) - # users = edgy.fields.ManyToMany(User) - # groups = edgy.fields.ManyToMany(Group) + # Important: embed_through must be set for enabling full proxying + # users = edgy.fields.ManyToMany("User", embed_through=False) + # groups = edgy.fields.ManyToMany("Group", embed_through=False) query = PermissionManager() diff --git a/edgy/core/db/fields/base.py b/edgy/core/db/fields/base.py index c2a0fd88..6e4ce042 100644 --- a/edgy/core/db/fields/base.py +++ b/edgy/core/db/fields/base.py @@ -1,5 +1,6 @@ import contextlib import copy +from abc import abstractmethod from collections.abc import Sequence from functools import cached_property from typing import ( @@ -455,6 +456,9 @@ def is_cross_db(self, owner_database: Optional["Database"] = None) -> bool: owner_database = self.owner.database return str(owner_database.url) != str(self.target.database.url) + @abstractmethod + def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: ... + def expand_relationship(self, value: Any) -> Any: """ Returns the related object or the relationship object diff --git a/edgy/core/db/fields/core.py b/edgy/core/db/fields/core.py index 3fb9fe24..471f3539 100644 --- a/edgy/core/db/fields/core.py +++ b/edgy/core/db/fields/core.py @@ -42,7 +42,9 @@ def __init__( ] = None, **kwargs: Any, ) -> None: - kwargs["default"] = None + kwargs["exclude"] = True + kwargs["null"] = True + kwargs["primary_key"] = False kwargs["field_type"] = kwargs["annotation"] = Any self.getter = getter self.fallback_getter = fallback_getter @@ -92,6 +94,14 @@ def to_model( ) -> dict[str, Any]: return {} + def clean( + self, + name: str, + value: Any, + for_query: bool = False, + ) -> dict[str, Any]: + return {} + def __get__(self, instance: "BaseModelType", owner: Any = None) -> Any: return self.compute_getter(self, instance, owner) diff --git a/edgy/core/db/fields/foreign_keys.py b/edgy/core/db/fields/foreign_keys.py index 6a7cbbc3..8ae5f373 100644 --- a/edgy/core/db/fields/foreign_keys.py +++ b/edgy/core/db/fields/foreign_keys.py @@ -185,6 +185,32 @@ def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any raise ValueError(f"cannot handle: {value} of type {type(value)}") return retdict + def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: + if not for_query: + return {} + retdict: dict[str, Any] = {} + column_names = self.owner.meta.field_to_column_names[self.name] + assert len(column_names) >= 1 + if value is None: + for column_name in column_names: + retdict[self.from_fk_field_name(name, column_name)] = None + elif isinstance(value, dict): + for column_name in column_names: + translated_name = self.from_fk_field_name(name, column_name) + if translated_name in value: + retdict[translated_name] = value[translated_name] + elif isinstance(value, BaseModel): + for column_name in column_names: + translated_name = self.from_fk_field_name(name, column_name) + if hasattr(value, translated_name): + retdict[translated_name] = getattr(value, translated_name) + elif len(column_names) == 1: + translated_name = self.from_fk_field_name(name, next(iter(column_names))) + retdict[translated_name] = value + else: + raise ValueError(f"cannot handle: {value} of type {type(value)}") + return retdict + def modify_input(self, name: str, kwargs: dict[str, Any]) -> None: phase = CURRENT_PHASE.get() column_names = self.get_column_names(name) diff --git a/edgy/core/db/fields/many_to_many.py b/edgy/core/db/fields/many_to_many.py index 18c2e81e..1cfb7b65 100644 --- a/edgy/core/db/fields/many_to_many.py +++ b/edgy/core/db/fields/many_to_many.py @@ -71,6 +71,16 @@ def reverse_embed_through_prefix(self) -> str: return self.reverse_name return f"{self.reverse_name}__{self.embed_through}" + def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: + if not for_query: + return {} + raise NotImplementedError(f"Not implemented yet for ManyToMany {name}") + + def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: + if not for_query: + return {} + raise NotImplementedError(f"Not implemented yet for ManyToMany {name}") + def get_relation(self, **kwargs: Any) -> ManyRelationProtocol: assert not isinstance(self.through, str), "through not initialized yet" return ManyRelation( @@ -95,23 +105,43 @@ def get_reverse_relation(self, **kwargs: Any) -> ManyRelationProtocol: ) def traverse_field(self, path: str) -> tuple[Any, str, str]: - if self.embed_through_prefix and path.startswith(self.embed_through_prefix): + if self.embed_through_prefix is False or self.embed_through_prefix: + # select embedded + if self.embed_through_prefix is not False and path.startswith( + self.embed_through_prefix + ): + return ( + self.through, + self.from_foreign_key, + path.removeprefix(self.embed_through_prefix).removeprefix("__"), + ) + # proxy return ( - self.through, - self.from_foreign_key, - path.removeprefix(self.embed_through_prefix).removeprefix("__"), + self.target, + self.reverse_name, + f'{path.removeprefix(self.name).removeprefix("__")}', ) return self.target, self.reverse_name, path.removeprefix(self.name).removeprefix("__") def reverse_traverse_field_fk(self, path: str) -> tuple[Any, str, str]: # used for target fk - if self.reverse_embed_through_prefix and path.startswith( + if self.reverse_embed_through_prefix is False or path.startswith( self.reverse_embed_through_prefix ): + # select embedded + if self.reverse_embed_through_prefix and path.startswith( + self.reverse_embed_through_prefix + ): + return ( + self.through, + self.to_foreign_key, + path.removeprefix(self.reverse_embed_through_prefix).removeprefix("__"), + ) + # proxy return ( - self.through, - self.to_foreign_key, - path.removeprefix(self.reverse_embed_through_prefix).removeprefix("__"), + self.owner, + self.name, + f'{path.removeprefix(self.reverse_name).removeprefix("__")}', ) return self.owner, self.name, path.removeprefix(self.reverse_name).removeprefix("__") diff --git a/edgy/core/db/querysets/base.py b/edgy/core/db/querysets/base.py index 4c037880..dd4d0bc6 100644 --- a/edgy/core/db/querysets/base.py +++ b/edgy/core/db/querysets/base.py @@ -91,7 +91,6 @@ def __init__( model_class: Union[type[BaseModelType], None] = None, database: Union["Database", None] = None, filter_clauses: Any = None, - or_clauses: Any = None, select_related: Any = None, prefetch_related: Any = None, limit_count: Any = None, @@ -111,7 +110,7 @@ def __init__( ) -> None: super().__init__(model_class=model_class) self.filter_clauses = [] if filter_clauses is None else filter_clauses - self.or_clauses = [] if or_clauses is None else or_clauses + self.or_clauses: Any = [] self.limit_count = limit_count self._select_related = [] if select_related is None else select_related self._prefetch_related = [] if prefetch_related is None else prefetch_related @@ -161,22 +160,21 @@ async def _resolve_clause_args(self, args: Any) -> Any: else: return await asyncio.gather(*result) - async def build_where_clause(self) -> Any: + async def build_where_clause(self, _: Any = None) -> Any: """Build a where clause from the filters which can be passed in a where function.""" - build_where_clause: list[Any] = [] - + # ignored args for passing build_where_clause in filter_clauses + where_clause: list[Any] = [] if self.or_clauses: or_clauses = await self._resolve_clause_args(self.or_clauses) - build_where_clause.append( + where_clause.append( or_clauses[0] if len(or_clauses) == 1 else clauses_mod.or_(*or_clauses) ) if self.filter_clauses: # we AND by default - build_where_clause.extend(await self._resolve_clause_args(self.filter_clauses)) - # this simplifies the integration. - # otherwise unrolling is required which needs extra wrapping with async functions - return clauses_mod.and_(*build_where_clause) + where_clause.extend(await self._resolve_clause_args(self.filter_clauses)) + # for nicer unpacking + return clauses_mod.and_(*where_clause) def _build_select_distinct(self, distinct_on: Optional[Sequence[str]], expression: Any) -> Any: """Filters selects only specific fields. Leave empty to use simple distinct""" @@ -232,15 +230,37 @@ def _build_tables_select_from_relationship(self) -> Any: # now use the one of the model_class itself model_database = None table = model_class.table_schema(self.active_schema) - if table.name not in tables: - select_from = sqlalchemy.sql.join( # type: ignore - select_from, - table, - *self._select_from_relationship_clause_generator( - foreign_key, table, reverse, former_table - ), - ) - tables[table.name] = table + + if table.name in tables: + former_table = table + continue + if foreign_key.is_m2m: + # we need to inject the through model for the select + model_class = foreign_key.through + table = model_class.table_schema(self.active_schema) + if reverse: + select_path = f"{foreign_key.from_foreign_key}__{select_path}" + else: + select_path = f"{foreign_key.to_foreign_key}__{select_path}" + # if select_path is empty + select_path = select_path.removesuffix("__") + if table.name in tables: + former_table = table + continue + if reverse: + foreign_key = model_class.meta.fields[foreign_key.to_foreign_key] + else: + foreign_key = model_class.meta.fields[foreign_key.from_foreign_key] + reverse = True + + select_from = sqlalchemy.sql.join( # type: ignore + select_from, + table, + *self._select_from_relationship_clause_generator( + foreign_key, table, reverse, former_table + ), + ) + tables[table.name] = table former_table = table return tables.values(), select_from @@ -253,6 +273,7 @@ def _select_from_relationship_clause_generator( former_table: Any, ) -> Any: column_names = foreign_key.get_column_names(foreign_key.name) + assert column_names, f"foreign key without column names detected: {foreign_key.name}" for col in column_names: colname = foreign_key.from_fk_field_name(foreign_key.name, col) if reverse else col if reverse: @@ -338,28 +359,23 @@ async def _build_select(self) -> Any: return expression - def _filter_query( + def _kwargs_to_clauses( self, kwargs: Any, - exclude: bool = False, - or_: bool = False, - ) -> "QuerySet": + ) -> tuple[list[Any], list[str]]: clauses = [] - filter_clauses = self.filter_clauses - or_clauses = self.or_clauses select_related = list(self._select_related) - prefetch_related = list(self._prefetch_related) # Making sure for queries we use the main class and not the proxy # And enable the parent if self.model_class.__is_proxy_model__: self.model_class = self.model_class.__parent__ - kwargs = clean_query_kwargs( + cleaned_kwargs = clean_query_kwargs( self.model_class, kwargs, self.embed_parent_filters, model_database=self.database ) - for key, value in kwargs.items(): + for key, value in cleaned_kwargs.items(): model_class, field_name, op, related_str, _, cross_db_remainder = crawl_relationship( self.model_class, key ) @@ -419,45 +435,8 @@ async def wrapper( field_name, op, model_class.table_schema(self.active_schema), value ) ) - if exclude: - - async def wrapper(queryset: "QuerySet") -> Any: - return clauses_mod.not_( - clauses_mod.and_(*(await self._resolve_clause_args(clauses))) - ) - if not or_: - filter_clauses.append(wrapper) - else: - or_clauses.append(wrapper) - else: - if not or_: - filter_clauses += clauses - else: - or_clauses += clauses - - return cast( - "QuerySet", - self.__class__( - model_class=self.model_class, - database=self._database, - filter_clauses=filter_clauses, - or_clauses=or_clauses, - select_related=select_related, - prefetch_related=prefetch_related, - limit_count=self.limit_count, - limit_offset=self._offset, - batch_size=self._batch_size, - order_by=self._order_by, - only_fields=self._only, - defer_fields=self._defer, - embed_parent=self.embed_parent, - embed_parent_filters=self.embed_parent_filters, - table=getattr(self, "_table", None), - exclude_secrets=self._exclude_secrets, - using_schema=self.using_schema, - ), - ) + return clauses, select_related def _prepare_order_by(self, order_by: str) -> Any: reverse = order_by.startswith("-") @@ -546,10 +525,10 @@ def _clone(self) -> "QuerySet": queryset.active_schema = self.get_schema() queryset._table = getattr(self, "_table", None) - queryset.filter_clauses = copy.copy(self.filter_clauses) - queryset.or_clauses = copy.copy(self.or_clauses) + queryset.filter_clauses = list(self.filter_clauses) + queryset.or_clauses = list(self.or_clauses) queryset.limit_count = copy.copy(self.limit_count) - queryset._select_related = copy.copy(self._select_related) + queryset._select_related = list(self._select_related) queryset._prefetch_related = copy.copy(self._prefetch_related) queryset._offset = copy.copy(self._offset) queryset._order_by = copy.copy(self._order_by) @@ -735,6 +714,8 @@ def _filter_or_exclude( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QuerySet", ] ], exclude: bool = False, @@ -745,14 +726,80 @@ def _filter_or_exclude( """ queryset: QuerySet = self._clone() if kwargs: - queryset = queryset._filter_query(kwargs, exclude=exclude, or_=or_) - if not clauses: + clauses = [*clauses, kwargs] + converted_clauses: Sequence[ + Union[ + sqlalchemy.sql.expression.BinaryExpression, + Callable[ + [QueryType], + Union[ + sqlalchemy.sql.expression.BinaryExpression, + Awaitable[sqlalchemy.sql.expression.BinaryExpression], + ], + ], + ] + ] = [] + for raw_clause in clauses: + if isinstance(raw_clause, dict): + extracted_clauses, queryset._select_related = queryset._kwargs_to_clauses( + kwargs=raw_clause + ) + if or_ and extracted_clauses: + + async def wrapper_and( + queryset: "QuerySet", + _extracted_clauses: Sequence[ + Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Callable[ + ["QueryType"], + Union[ + "sqlalchemy.sql.expression.BinaryExpression", + Awaitable["sqlalchemy.sql.expression.BinaryExpression"], + ], + ], + ] + ] = extracted_clauses, + ) -> Any: + return clauses_mod.and_( + *(await self._resolve_clause_args(_extracted_clauses)) + ) + + if len(clauses) == 1: + # add to global or + assert not exclude + queryset.or_clauses.append(wrapper_and) + return queryset + converted_clauses.append(wrapper_and) + else: + converted_clauses.extend(extracted_clauses) + elif isinstance(raw_clause, QuerySet): + converted_clauses.append(raw_clause.build_where_clause) + for related in raw_clause._select_related: + if related not in queryset._select_related: + queryset._select_related.append(related) + + else: + converted_clauses.append(raw_clause) + if not converted_clauses: return queryset - op = clauses_mod.or_ if or_ else clauses_mod.and_ + if exclude: - queryset.filter_clauses.append(clauses_mod.not_(op(*clauses))) + op = clauses_mod.and_ if not or_ else clauses_mod.or_ + + async def wrapper(queryset: "QuerySet") -> Any: + return clauses_mod.not_(op(*(await self._resolve_clause_args(converted_clauses)))) + + queryset.filter_clauses.append(wrapper) + elif or_: + + async def wrapper(queryset: "QuerySet") -> Any: + return clauses_mod.or_(*(await self._resolve_clause_args(converted_clauses))) + + queryset.filter_clauses.append(wrapper) else: - queryset.filter_clauses.append(op(*clauses)) + # default to and + queryset.filter_clauses.extend(converted_clauses) return queryset async def _model_based_delete(self) -> int: @@ -831,6 +878,8 @@ def filter( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QuerySet", ], **kwargs: Any, ) -> "QuerySet": @@ -859,6 +908,8 @@ def or_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QuerySet", ], **kwargs: Any, ) -> "QuerySet": @@ -878,13 +929,14 @@ def and_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], ], **kwargs: Any, ) -> "QuerySet": """ Filters the QuerySet by the AND operand. Alias of filter. """ - return self.filter(*clauses, **kwargs) + return self._filter_or_exclude(clauses=clauses, kwargs=kwargs) def not_( self, @@ -897,6 +949,8 @@ def not_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QuerySet", ], **kwargs: Any, ) -> "QuerySet": @@ -916,6 +970,8 @@ def exclude( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QuerySet", ], **kwargs: Any, ) -> "QuerySet": @@ -1062,8 +1118,7 @@ def select_related(self, related: Any) -> "QuerySet": if not isinstance(related, (list, tuple)): related = [related] - related = list(queryset._select_related) + related - queryset._select_related = related + queryset._select_related.extend(related) return queryset async def values( diff --git a/edgy/core/db/querysets/types.py b/edgy/core/db/querysets/types.py index bf278d5a..16509961 100644 --- a/edgy/core/db/querysets/types.py +++ b/edgy/core/db/querysets/types.py @@ -46,6 +46,8 @@ def filter( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QueryType", ], **kwargs: Any, ) -> "QueryType": ... @@ -65,6 +67,7 @@ def or_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + "QueryType", ], **kwargs: Any, ) -> "QueryType": @@ -84,6 +87,8 @@ def and_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QueryType", ], **kwargs: Any, ) -> "QueryType": @@ -103,6 +108,8 @@ def not_( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QueryType", ], **kwargs: Any, ) -> "QueryType": @@ -123,6 +130,8 @@ def exclude( Awaitable["sqlalchemy.sql.expression.BinaryExpression"], ], ], + dict[str, Any], + "QueryType", ], **kwargs: Any, ) -> "QueryType": ... diff --git a/edgy/core/db/relationships/related_field.py b/edgy/core/db/relationships/related_field.py index 0f9cebcc..0e65ffb2 100644 --- a/edgy/core/db/relationships/related_field.py +++ b/edgy/core/db/relationships/related_field.py @@ -96,9 +96,7 @@ def is_m2m(self) -> bool: return self.foreign_key.is_m2m def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: - if not for_query: - return {} - return self.related_to.meta.pk.clean("pk", value, for_query=for_query) # type: ignore + return self.foreign_key.reverse_clean(name, value, for_query=for_query) def __repr__(self) -> str: return f"<{self.__class__.__name__}: {self}>" diff --git a/tests/clauses/test_or_clauses.py b/tests/clauses/test_or_clauses.py index 81537c5e..f912cb8e 100644 --- a/tests/clauses/test_or_clauses.py +++ b/tests/clauses/test_or_clauses.py @@ -160,6 +160,20 @@ async def test_filter_or_clause_select(): assert len(results) == 2 +async def test_filter_or_clause_select_new(): + user = await User.query.create(name="Adam", email="adam@edgy.dev") + await User.query.create(name="Edgy", email="adam@edgy.dev") + + results = await User.query.or_({"name": "Test"}, {"name": "Adam"}) + + assert len(results) == 1 + assert results[0].pk == user.pk + + results = await User.query.or_({"name": "Edgy"}, {"name": "Adam"}) + + assert len(results) == 2 + + async def test_filter_or_clause_mixed(): user = await User.query.create(name="Adam", email="adam@edgy.dev") await User.query.create(name="Edgy", email="adam@edgy.dev") diff --git a/tests/contrib/contenttypes/test_contenttypes.py b/tests/contrib/contenttypes/test_contenttypes.py index 4f21af2b..3e0708c4 100644 --- a/tests/contrib/contenttypes/test_contenttypes.py +++ b/tests/contrib/contenttypes/test_contenttypes.py @@ -62,7 +62,7 @@ async def create_test_database(): @pytest.fixture(autouse=True, scope="function") async def rollback_transactions(): - async with models.database: + async with models: yield diff --git a/tests/contrib/permissions/test_advanced_permissions.py b/tests/contrib/permissions/test_advanced_permissions.py new file mode 100644 index 00000000..6d035e25 --- /dev/null +++ b/tests/contrib/permissions/test_advanced_permissions.py @@ -0,0 +1,50 @@ +import pytest + +import edgy +from edgy.contrib.permissions import BasePermission +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, use_existing=False) +models = edgy.Registry( + database=edgy.Database(database, force_rollback=True), with_content_type=True +) + + +class User(edgy.Model): + name = edgy.fields.CharField(max_length=100) + + class Meta: + registry = models + + +class Permission(BasePermission): + users = edgy.fields.ManyToMany("User", embed_through=False) + + class Meta: + registry = models + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +@pytest.fixture(autouse=True, scope="function") +async def rollback_transactions(): + async with models: + yield + + +async def test_querying(): + user = await User.query.create(name="edgy") + permission = await Permission.query.create(users=[user], name="view") + assert await Permission.query.users("view").get() == user + assert await Permission.query.users("edit").count() == 0 + assert await Permission.query.permissions_of(user).get() == permission diff --git a/tests/contrib/permissions/test_group_permissions.py b/tests/contrib/permissions/test_group_permissions.py new file mode 100644 index 00000000..368dbc22 --- /dev/null +++ b/tests/contrib/permissions/test_group_permissions.py @@ -0,0 +1,66 @@ +import pytest + +import edgy +from edgy.contrib.permissions import BasePermission +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, use_existing=False) +models = edgy.Registry( + database=edgy.Database(database, force_rollback=True), with_content_type=True +) + + +class User(edgy.Model): + name = edgy.fields.CharField(max_length=100) + + class Meta: + registry = models + + +class Group(edgy.Model): + name = edgy.fields.CharField(max_length=100) + users = edgy.fields.ManyToMany("User", embed_through=False) + + class Meta: + registry = models + + +class Permission(BasePermission): + users = edgy.fields.ManyToMany("User", embed_through=False) + groups = edgy.fields.ManyToMany("Group", embed_through=False) + + class Meta: + registry = models + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +@pytest.fixture(autouse=True, scope="function") +async def rollback_transactions(): + async with models: + yield + + +async def test_querying(): + user = await User.query.create(name="edgy") + group = await Group.query.create(name="admin", users=[user]) + permission = await Permission.query.create(users=[user], name="view") + permission2 = await Permission.query.create(groups=[group], name="admin") + assert await Permission.query.filter(name="admin").get() + assert await Permission.query.permissions_of(user).get(name="view") == permission + permissions = await Permission.query.permissions_of(group) + assert permissions == [permission2] + assert await Permission.query.users("view").get() == user + assert await Permission.query.users("admin").get() == user + assert await Permission.query.users("edit").count() == 0 + assert await Permission.query.permissions_of(user).count() == 2 diff --git a/tests/contrib/permissions/test_simple_permissions.py b/tests/contrib/permissions/test_simple_permissions.py new file mode 100644 index 00000000..665548fc --- /dev/null +++ b/tests/contrib/permissions/test_simple_permissions.py @@ -0,0 +1,66 @@ +import pytest + +import edgy +from edgy.contrib.permissions import BasePermission +from edgy.testclient import DatabaseTestClient +from tests.settings import DATABASE_URL + +pytestmark = pytest.mark.anyio + +database = DatabaseTestClient(DATABASE_URL, use_existing=False) +models = edgy.Registry( + database=edgy.Database(database, force_rollback=True), with_content_type=True +) + + +class User(edgy.Model): + name = edgy.fields.CharField(max_length=100, unique=True) + + class Meta: + registry = models + + +class Permission(BasePermission): + users = edgy.fields.ManyToMany("User", embed_through=False) + + class Meta: + registry = models + unique_together = [("name",)] + + @classmethod + def get_description(cls, field, instance, owner=None) -> str: + return instance.name.upper() + + @classmethod + def set_description(cls, field, instance, value) -> None: + instance.__dict__["test"] = value + + +@pytest.fixture(autouse=True, scope="module") +async def create_test_database(): + async with database: + await models.create_all() + yield + if not database.drop: + await models.drop_all() + + +@pytest.fixture(autouse=True, scope="function") +async def rollback_transactions(): + async with models: + yield + + +async def test_permission(): + permission = await Permission.query.create(name="View") + assert permission.description == "VIEW" + permission.description = "toll" + assert permission.test == "toll" + + +async def test_querying(): + user = await User.query.create(name="edgy") + permission = await Permission.query.create(users=[user], name="view") + assert await Permission.query.users("view").get() == user + assert await Permission.query.users("edit").count() == 0 + assert await Permission.query.permissions_of(user).get() == permission