Skip to content

Commit

Permalink
Changes:
Browse files Browse the repository at this point in the history
- cleanup kwargs parsing of querysets
- drastically improve queryset filter methods
- add reverse_clean method and fix foreignfields
- fixes for permission and add tests (failing)
  • Loading branch information
devkral committed Oct 12, 2024
1 parent 4dcde82 commit cd2bf04
Show file tree
Hide file tree
Showing 15 changed files with 528 additions and 120 deletions.
8 changes: 7 additions & 1 deletion docs/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,17 @@ hide:

- ComputedField.
- Permission template.
- `reverse_clean` for ForeignKeys.
- Expanded filter methods of querysets (can pass now dict and querysets).

### Fixed
### Changed

- Managers use now instance attributes (database, schema).

### Fixed

- `select_related` works across ManyToMany fields.

## 0.17.4

### Fixed
Expand Down
126 changes: 100 additions & 26 deletions edgy/contrib/permissions/managers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast

from edgy.core.db.models.managers import Manager
from edgy.core.db.models.types import BaseModelType
Expand All @@ -17,86 +17,160 @@ def permissions_of(self, sources: Sequence[BaseModelType] | BaseModelType) -> Qu
sources = [sources]
if len(sources) == 0:
# none
return cast(QuerySet, self.filter(and_()))
return cast("QuerySet", self.filter(and_()))
UserField = self.owner.meta.fields["users"]
assert (
UserField.embed_through is False or UserField.embed_through
), "users field need embed_through=foo|False."
GroupField = self.owner.meta.fields.get("groups", None)
query = cast(QuerySet, self.all())
assert (
GroupField is None or GroupField.embed_through is False or GroupField.embed_through
), "groups field need embed_through=foo|False."
query = cast("QuerySet", self.all())

groups_field_user: Literal[False] | str = False
if GroupField is not None:
groups_field_user = GroupField.target.meta.fields[
self.owner.users_field_group
].reverse_name
assert isinstance(
groups_field_user, str
), f"{GroupField.target} {self.owner.users_field_group} field needs reverse_name."
for source in sources:
if isinstance(source, UserField.target):
clause: dict[str, Any] = {"users": source}
if GroupField is not None:
clause[f"groups__{self.owner.users_field_group}"] = source
clause: dict[str, Any] = {"users__pk": source}
if groups_field_user:
clause[f"{groups_field_user}__{self.owner.users_field_group}__pk"] = source
query = query.or_(**clause)
return query
elif GroupField is not None and isinstance(source, GroupField.target):
query = query.or_(groups=source)
query = query.or_(groups__pk=source)
else:
raise ValueError(f"Invalid source: {source}.")
return query

def users(
self,
permissions: Sequence[str] | str,
model_names: Sequence[str] | str | None = None,
objects: Sequence[BaseModelType] | BaseModelType | None = None,
model_names: Sequence[str | None] | str | None = None,
objects: Sequence[BaseModelType | None] | BaseModelType | None = None,
include_null_model_name: bool = True,
include_null_object: bool = True,
) -> QuerySet:
if isinstance(permissions, str):
permissions = [permissions]
if isinstance(model_names, str):
model_names = [model_names]
if model_names is not None and include_null_model_name:
model_names = [*model_names, None]
if isinstance(objects, BaseModelType):
objects = [objects]
if objects is not None and include_null_object:
objects = [*objects, None]
UserField = self.owner.meta.fields["users"]
assert (
UserField.embed_through is False or UserField.embed_through
), "users field need embed_through=foo|False."
GroupField = self.owner.meta.fields.get("groups", None)
assert (
GroupField is None or GroupField.embed_through is False or GroupField.embed_through
), "groups field need embed_through=foo|False."
assert (
GroupField is None
or GroupField.target.meta.fields[self.owner.users_field_group].embed_through is False
or GroupField.target.meta.fields[self.owner.users_field_group].embed_through
), f"{GroupField.target} {self.owner.users_field_group} field need embed_through=foo|False."
ModelNameField = self.owner.meta.fields.get("model_name", None)
assert (
ModelNameField is None
or ModelNameField.embed_through is False
or ModelNameField.embed_through
), "model_name field need embed_through=foo|False."
ContentTypeField = self.owner.meta.fields.get("obj", None)
assert (
ContentTypeField is None
or ContentTypeField.embed_through is False
or ContentTypeField.embed_through
), "obj field need embed_through=foo|False."
if objects is not None and len(objects) == 0:
# none
return cast("QuerySet", UserField.target.query.filter(and_()))
clauses: dict[str, Any] = {f"{UserField.reverse_name}__name__in": permissions}
clauses: list[dict[str, Any]] = [{f"{UserField.reverse_name}__name__in": permissions}]
if model_names is not None:
if ModelNameField is not None:
clauses[f"{UserField.reverse_name}__model_name__in"] = model_names
clauses[-1][f"{UserField.reverse_name}__model_name__in"] = model_names
elif ContentTypeField is not None:
clauses[f"{UserField.reverse_name}__obj__name__in"] = model_names
clauses[-1][f"{UserField.reverse_name}__obj__name__in"] = model_names

if GroupField is not None:
clauses[f"{self.owner.groups_field_user}__{GroupField.reverse_name}__name__in"] = (
permissions
)
clauses.append({})
groups_field_user = GroupField.target.meta.fields[
self.owner.users_field_group
].reverse_name
assert isinstance(
groups_field_user, str
), f"{GroupField.target} {self.owner.users_field_group} field needs reverse_name."
clauses[-1][f"{groups_field_user}__{GroupField.reverse_name}__name__in"] = permissions
if model_names is not None:
if ModelNameField is not None:
clauses[
f"{self.owner.groups_field_user}__{GroupField.reverse_name}__model_name__in"
clauses[-1][
f"{groups_field_user}__{GroupField.reverse_name}__model_name__in"
] = model_names
elif ContentTypeField is not None:
clauses[
f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj__name__in"
clauses[-1][
f"{groups_field_user}__{GroupField.reverse_name}__obj__name__in"
] = model_names

query = cast("QuerySet", UserField.target.query.filter(**clauses))
query = cast("QuerySet", UserField.target.query.or_(*clauses))

if objects is not None:
obj_clauses = []
for obj in objects:
clause = {f"{UserField.reverse_name}__obj": obj}
obj_clauses.append({f"{UserField.reverse_name}__obj": obj})
if GroupField is not None:
clause[f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj"] = obj
query = query.or_(**clause)
obj_clauses[-1][
f"{self.owner.groups_field_user}__{GroupField.reverse_name}__obj"
] = obj
query = query.or_(*obj_clauses)
return query

def groups(
self,
permissions: Sequence[str] | str,
model_names: Sequence[str] | str | None = None,
objects: Sequence[BaseModelType] | BaseModelType | None = None,
model_names: Sequence[str | None] | str | None = None,
objects: Sequence[BaseModelType | None] | BaseModelType | None = None,
include_null_model_name: bool = True,
include_null_object: bool = True,
) -> QuerySet:
if isinstance(permissions, str):
permissions = [permissions]
if isinstance(model_names, str):
model_names = [model_names]
if model_names is not None and include_null_model_name:
model_names = [*model_names, None]
if isinstance(objects, BaseModelType):
objects = [objects]
if objects is not None and include_null_object:
objects = [*objects, None]
GroupField = self.owner.meta.fields["groups"]
assert (
GroupField.embed_through is False or GroupField.embed_through
), "groups field need embed_through=foo|False."
assert (
GroupField.target.meta.fields[self.owner.users_field_group].embed_through is False
or GroupField.target.meta.fields[self.owner.users_field_group].embed_through
), f"{GroupField.target} {self.owner.users_field_group} field need embed_through=foo|False."
ModelNameField = self.owner.meta.fields.get("model_name", None)
assert (
ModelNameField is None
or ModelNameField.embed_through is False
or ModelNameField.embed_through
), "model_name field need embed_through=foo|False."
ContentTypeField = self.owner.meta.fields.get("obj", None)
assert (
ContentTypeField is None
or ContentTypeField.embed_through is False
or ContentTypeField.embed_through
), "obj field need embed_through=foo|False."
if objects is not None and len(objects) == 0:
# none
return cast("QuerySet", GroupField.target.query.filter(and_()))
Expand Down
6 changes: 3 additions & 3 deletions edgy/contrib/permissions/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


class BasePermission(edgy.Model):
groups_field_user: ClassVar[str] = "groups"
users_field_group: ClassVar[str] = "users"
name: str = edgy.fields.CharField(max_length=100, null=False)
# model_name: str = edgy.fields.CharField(max_length=100, null=True)
Expand All @@ -18,8 +17,9 @@ class BasePermission(edgy.Model):
fallback_getter=lambda field, instance, owner: instance.name,
)

# users = edgy.fields.ManyToMany(User)
# groups = edgy.fields.ManyToMany(Group)
# Important: embed_through must be set for enabling full proxying
# users = edgy.fields.ManyToMany("User", embed_through=False)
# groups = edgy.fields.ManyToMany("Group", embed_through=False)

query = PermissionManager()

Expand Down
4 changes: 4 additions & 0 deletions edgy/core/db/fields/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import copy
from abc import abstractmethod
from collections.abc import Sequence
from functools import cached_property
from typing import (
Expand Down Expand Up @@ -455,6 +456,9 @@ def is_cross_db(self, owner_database: Optional["Database"] = None) -> bool:
owner_database = self.owner.database
return str(owner_database.url) != str(self.target.database.url)

@abstractmethod
def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]: ...

def expand_relationship(self, value: Any) -> Any:
"""
Returns the related object or the relationship object
Expand Down
12 changes: 11 additions & 1 deletion edgy/core/db/fields/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def __init__(
] = None,
**kwargs: Any,
) -> None:
kwargs["default"] = None
kwargs["exclude"] = True
kwargs["null"] = True
kwargs["primary_key"] = False
kwargs["field_type"] = kwargs["annotation"] = Any
self.getter = getter
self.fallback_getter = fallback_getter
Expand Down Expand Up @@ -92,6 +94,14 @@ def to_model(
) -> dict[str, Any]:
return {}

def clean(
self,
name: str,
value: Any,
for_query: bool = False,
) -> dict[str, Any]:
return {}

def __get__(self, instance: "BaseModelType", owner: Any = None) -> Any:
return self.compute_getter(self, instance, owner)

Expand Down
26 changes: 26 additions & 0 deletions edgy/core/db/fields/foreign_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,32 @@ def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any
raise ValueError(f"cannot handle: {value} of type {type(value)}")
return retdict

def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]:
if not for_query:
return {}
retdict: dict[str, Any] = {}
column_names = self.owner.meta.field_to_column_names[self.name]
assert len(column_names) >= 1
if value is None:
for column_name in column_names:
retdict[self.from_fk_field_name(name, column_name)] = None
elif isinstance(value, dict):
for column_name in column_names:
translated_name = self.from_fk_field_name(name, column_name)
if translated_name in value:
retdict[translated_name] = value[translated_name]
elif isinstance(value, BaseModel):
for column_name in column_names:
translated_name = self.from_fk_field_name(name, column_name)
if hasattr(value, translated_name):
retdict[translated_name] = getattr(value, translated_name)
elif len(column_names) == 1:
translated_name = self.from_fk_field_name(name, next(iter(column_names)))
retdict[translated_name] = value
else:
raise ValueError(f"cannot handle: {value} of type {type(value)}")
return retdict

def modify_input(self, name: str, kwargs: dict[str, Any]) -> None:
phase = CURRENT_PHASE.get()
column_names = self.get_column_names(name)
Expand Down
46 changes: 38 additions & 8 deletions edgy/core/db/fields/many_to_many.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ def reverse_embed_through_prefix(self) -> str:
return self.reverse_name
return f"{self.reverse_name}__{self.embed_through}"

def clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]:
if not for_query:
return {}
raise NotImplementedError(f"Not implemented yet for ManyToMany {name}")

def reverse_clean(self, name: str, value: Any, for_query: bool = False) -> dict[str, Any]:
if not for_query:
return {}
raise NotImplementedError(f"Not implemented yet for ManyToMany {name}")

def get_relation(self, **kwargs: Any) -> ManyRelationProtocol:
assert not isinstance(self.through, str), "through not initialized yet"
return ManyRelation(
Expand All @@ -95,23 +105,43 @@ def get_reverse_relation(self, **kwargs: Any) -> ManyRelationProtocol:
)

def traverse_field(self, path: str) -> tuple[Any, str, str]:
if self.embed_through_prefix and path.startswith(self.embed_through_prefix):
if self.embed_through_prefix is False or self.embed_through_prefix:
# select embedded
if self.embed_through_prefix is not False and path.startswith(
self.embed_through_prefix
):
return (
self.through,
self.from_foreign_key,
path.removeprefix(self.embed_through_prefix).removeprefix("__"),
)
# proxy
return (
self.through,
self.from_foreign_key,
path.removeprefix(self.embed_through_prefix).removeprefix("__"),
self.target,
self.reverse_name,
f'{path.removeprefix(self.name).removeprefix("__")}',
)
return self.target, self.reverse_name, path.removeprefix(self.name).removeprefix("__")

def reverse_traverse_field_fk(self, path: str) -> tuple[Any, str, str]:
# used for target fk
if self.reverse_embed_through_prefix and path.startswith(
if self.reverse_embed_through_prefix is False or path.startswith(
self.reverse_embed_through_prefix
):
# select embedded
if self.reverse_embed_through_prefix and path.startswith(
self.reverse_embed_through_prefix
):
return (
self.through,
self.to_foreign_key,
path.removeprefix(self.reverse_embed_through_prefix).removeprefix("__"),
)
# proxy
return (
self.through,
self.to_foreign_key,
path.removeprefix(self.reverse_embed_through_prefix).removeprefix("__"),
self.owner,
self.name,
f'{path.removeprefix(self.reverse_name).removeprefix("__")}',
)
return self.owner, self.name, path.removeprefix(self.reverse_name).removeprefix("__")

Expand Down
Loading

0 comments on commit cd2bf04

Please sign in to comment.