
Commit

- fix typings
- pass through tables used for select, so there should be a correct
  mapping
devkral committed Oct 13, 2024
1 parent d9b4b6b commit e1f9418
Showing 3 changed files with 86 additions and 44 deletions.
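
In short: the queryset now hands back the table-to-model mapping it used to build the SELECT, and row hydration consumes that mapping instead of re-deriving tables from the schema. A rough sketch of the resulting flow, using only names that appear in the diffs below; the queryset and its models are assumed to exist already, and the snippet is illustrative rather than part of the commit:

# Sketch only: the flow after this commit, assuming `queryset` is an edgy QuerySet.
async def fetch_models(queryset):
    # as_select_with_tables() returns the SELECT expression plus the mapping
    # {table key: (sqlalchemy.Table, model class)} that was used to build it.
    expression, tables_and_models = await queryset.as_select_with_tables()
    async with queryset.database as database:
        rows = await database.fetch_all(expression)
    # The same mapping is passed through to row hydration, so columns are
    # resolved against the exact tables that were selected.
    return [
        await queryset.model_class.from_sqla_row(
            row,
            tables_and_models=tables_and_models,
            select_related=queryset._select_related,
            using_schema=queryset.active_schema,
            database=queryset.database,
        )
        for row in rows
    ]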
2 changes: 1 addition & 1 deletion edgy/contrib/permissions/managers.py
@@ -27,7 +27,7 @@ def permissions_of(self, sources: Sequence[BaseModelType] | BaseModelType) -> Qu
GroupField is None or GroupField.embed_through is False or GroupField.embed_through
), "groups field need embed_through=foo|False."

clauses: [dict[str, Any]] = []
clauses: list[dict[str, Any]] = []
for source in sources:
if isinstance(source, UserField.target):
clauses.append({"users__pk": source})
55 changes: 35 additions & 20 deletions edgy/core/db/models/mixins/row.py
@@ -13,6 +13,7 @@
from sqlalchemy.engine.result import Row

from edgy import Database, Model
from edgy.core.db.models.types import BaseModelType


class ModelRowMixin:
@@ -21,30 +21,21 @@ class ModelRowMixin:
"""

@classmethod
def can_load_from_row(
cls: type["Model"],
row: "Row",
table: Optional["Table"] = None,
using_schema: Optional[str] = None,
) -> bool:
if table is None:
table = cls.table_schema(using_schema)
# FIXME: the extraction of cols does not work because of clashing names new names are assigned
# find a way to get the original name or control the select to use speaking aliases
def can_load_from_row(cls: type["Model"], row: "Row", table: "Table") -> bool:
return bool(
cls.meta.registry is not None
and not cls.meta.abstract
and all(
row._mapping.get(getattr(table.columns, col)) is not None
and row._mapping.get(col) is not None
for col in cls.pkcolumns
row._mapping.get(getattr(table.columns, col)) is not None for col in cls.pkcolumns
)
)

@classmethod
async def from_sqla_row(
cls: type["Model"],
row: "Row",
# contain the mappings used for select
tables_and_models: dict[str, tuple["Table", type["BaseModelType"]]],
select_related: Optional[Sequence[Any]] = None,
prefetch_related: Optional[Sequence["Prefetch"]] = None,
is_only_fields: bool = False,
@@ -53,8 +45,6 @@ async def from_sqla_row(
exclude_secrets: bool = False,
using_schema: Optional[str] = None,
database: Optional["Database"] = None,
# local only parameter
table: Optional["Table"] = None,
) -> Optional["Model"]:
"""
Class method to convert a SQLAlchemy Row result into an EdgyModel row type.
@@ -93,12 +83,20 @@ async def from_sqla_row(
detail=f'Selected field "{field_name}" is not a RelationshipField on {cls}.'
) from None

if not model_class.can_load_from_row(row, using_schema=using_schema):
if not model_class.can_load_from_row(
row,
tables_and_models[
model_class.meta.tablename
if using_schema is None
else f"{using_schema}.{model_class.meta.tablename}"
][0],
):
continue
if remainder:
# don't pass table, it is only for the main model_class
item[field_name] = await model_class.from_sqla_row(
row,
tables_and_models=tables_and_models,
select_related=[remainder],
prefetch_related=prefetch_related,
exclude_secrets=exclude_secrets,
@@ -109,11 +107,14 @@ async def from_sqla_row(
# don't pass table, it is only for the main model_class
item[field_name] = await model_class.from_sqla_row(
row,
tables_and_models=tables_and_models,
exclude_secrets=exclude_secrets,
using_schema=using_schema,
database=database,
)
table_columns = cls.table_schema(using_schema).columns
table_columns = tables_and_models[
cls.meta.tablename if using_schema is None else f"{using_schema}.{cls.meta.tablename}"
][0].columns
# Populate the related names
# Making sure if the model being queried is not inside a select related
# This way it is not overwritten by any value
@@ -142,10 +143,13 @@ async def from_sqla_row(
# For the related fields. We simply change the structure of the model
# and rebuild it with the new fields.
proxy_model = model_related.proxy_model(**child_item)
# don't pass table, it is only for the main model_class
proxy_database = database if model_related.database is cls.database else None
# don't pass a table. It is not in the row (select related path) and has no explicit table
proxy_model = apply_instance_extras(
proxy_model, model_related, using_schema, database=proxy_database
proxy_model,
model_related,
using_schema,
database=proxy_database,
)
proxy_model.identifying_db_fields = foreign_key.related_columns

@@ -169,6 +173,7 @@ async def from_sqla_row(
if column in row._mapping:
item[column.key] = row._mapping[column]
elif column.name in row._mapping:
# FIXME: this path should not happen, we use the right tables
# fallback, sometimes the column is not found
item[column.key] = row._mapping[column.name]
model: Model = (
@@ -177,7 +182,17 @@ async def from_sqla_row(
else cls.proxy_model(**item)
)
# Apply the schema to the model
model = apply_instance_extras(model, cls, using_schema, database=database, table=table)
model = apply_instance_extras(
model,
cls,
using_schema,
database=database,
table=tables_and_models[
cls.meta.tablename
if using_schema is None
else f"{using_schema}.{cls.meta.tablename}"
][0],
)

# Handle prefetch related fields.
await cls.__handle_prefetch_related(
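
The lookups added in this file key tables_and_models by the model's tablename, or by "schema.tablename" when a schema is active. A hypothetical helper, not part of the commit, that captures the same key scheme:

# Hypothetical helper mirroring the lookup logic in the hunks above.
def resolve_table(tables_and_models, model_class, using_schema=None):
    # tables_and_models maps "tablename" or "schema.tablename" to (Table, model class).
    key = (
        model_class.meta.tablename
        if using_schema is None
        else f"{using_schema}.{model_class.meta.tablename}"
    )
    return tables_and_models[key]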
73 changes: 50 additions & 23 deletions edgy/core/db/querysets/base.py
@@ -186,20 +186,22 @@ def _build_select_distinct(self, distinct_on: Optional[Sequence[str]], expressio

def _build_tables_select_from_relationship(
self,
) -> tuple[str, dict[str, tuple[type["BaseModelType"], "sqlalchemy.Table"], Any]]:
) -> tuple[str, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]], Any]:
"""
Builds the tables relationships and joins.
When a table contains more than one foreign key pointing to the same
destination table, a lookup for the related field is made to understand
from which foreign key the table is looked up from.
"""
queryset: QuerySet = self
queryset: BaseQuerySet = self

select_from = queryset.table
maintablekey = select_from.key
tables_and_models = {select_from.key: (select_from, self.model_class)}
transitions: dict[(str, str), Any] = {}
transitions_is_full_outer: dict[(str, str), bool] = {}
tables_and_models: dict[str, tuple[sqlalchemy.Table, type[BaseModelType]]] = {
select_from.key: (select_from, self.model_class)
}
transitions: dict[tuple[str, str], Any] = {}
transitions_is_full_outer: dict[tuple[str, str], bool] = {}

# Select related
for select_path in queryset._select_related:
@@ -311,7 +313,9 @@ def _validate_only_and_defer(self) -> None:
if self._only and self._defer:
raise QuerySetError("You cannot use .only() and .defer() at the same time.")

async def as_select(self) -> Any:
async def as_select_with_tables(
self,
) -> tuple[Any, dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]]]:
"""
Builds the query select based on the given parameters and filters.
"""
@@ -379,7 +383,12 @@ async def as_select(self) -> Any:
expression = queryset._build_select_distinct(
queryset.distinct_on, expression=expression
)
return expression
return expression, tables_and_models

async def as_select(
self,
) -> Any:
return (await self.as_select_with_tables())[0]

def _kwargs_to_clauses(
self,
@@ -495,16 +504,21 @@ async def _embed_parent_in_result(
return cast(EdgyModel, result), new_result

async def _get_or_cache_row(
self, row: Any, extra_attr: str = "", raw: bool = False
self,
row: Any,
tables_and_models: dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]],
extra_attr: str = "",
raw: bool = False,
) -> tuple[EdgyModel, EdgyModel]:
is_only_fields = bool(self._only)
is_defer_fields = bool(self._defer)
raw_result, result = (
await self._cache.aget_or_cache_many(
self.model_class,
[row],
cache_fn=lambda row: self.model_class.from_sqla_row(
row,
cache_fn=lambda _row: self.model_class.from_sqla_row(
_row,
tables_and_models=tables_and_models,
select_related=self._select_related,
is_only_fields=is_only_fields,
only_fields=self._only,
Expand All @@ -513,7 +527,6 @@ async def _get_or_cache_row(
exclude_secrets=self._exclude_secrets,
using_schema=self.active_schema,
database=self.database,
table=self.table,
),
transform_fn=self._embed_parent_in_result,
)
@@ -578,7 +591,10 @@ def _clear_cache(self, keep_result_cache: bool = False) -> None:
self._cache_current_row: Optional[sqlalchemy.Row] = None

async def _handle_batch(
self, batch: Sequence[sqlalchemy.Row], queryset: "BaseQuerySet"
self,
batch: Sequence[sqlalchemy.Row],
tables_and_models: dict[str, tuple["sqlalchemy.Table", type["BaseModelType"]]],
queryset: "BaseQuerySet",
) -> Sequence[tuple[BaseModelType, BaseModelType]]:
is_only_fields = bool(queryset._only)
is_defer_fields = bool(queryset._defer)
@@ -638,6 +654,7 @@ async def _handle_batch(
batch,
cache_fn=lambda row: self.model_class.from_sqla_row(
row,
tables_and_models=tables_and_models,
select_related=self._select_related,
is_only_fields=is_only_fields,
only_fields=self._only,
Expand All @@ -646,7 +663,6 @@ async def _handle_batch(
exclude_secrets=self._exclude_secrets,
using_schema=self.active_schema,
database=self.database,
table=self.table,
),
transform_fn=self._embed_parent_in_result,
),
@@ -669,7 +685,7 @@ async def _execute_iterate(
# activates distinct, not distinct on
queryset = queryset.distinct() # type: ignore

expression = await queryset.as_select()
expression, tables_and_models = await queryset.as_select_with_tables()

if not fetch_all_at_once and bool(queryset.database.force_rollback):
# force_rollback on db = we have only one connection
@@ -691,7 +707,9 @@ async def _execute_iterate(
if fetch_all_at_once:
async with queryset.database as database:
batch = cast(Sequence[sqlalchemy.Row], await database.fetch_all(expression))
for row_num, result in enumerate(await self._handle_batch(batch, queryset)):
for row_num, result in enumerate(
await self._handle_batch(batch, tables_and_models, queryset)
):
if counter == 0:
self._cache_first = result
last_element = result
@@ -708,7 +726,9 @@ async def _execute_iterate(
# clear only result cache
self._cache.clear()
self._cache_fetch_all = False
for row_num, result in enumerate(await self._handle_batch(batch, queryset)):
for row_num, result in enumerate(
await self._handle_batch(batch, tables_and_models, queryset)
):
if counter == 0:
self._cache_first = result
last_element = result
@@ -861,8 +881,7 @@ async def _get_raw(self, **kwargs: Any) -> tuple[BaseModelType, Any]:
return await filter_query._get_raw()

queryset: BaseQuerySet = self

expression = (await queryset.as_select()).limit(2)
expression, tables_and_models = await queryset.limit(2).as_select_with_tables()
check_db_connection(queryset.database)
async with queryset.database as database:
rows = await database.fetch_all(expression)
Expand All @@ -874,7 +893,9 @@ async def _get_raw(self, **kwargs: Any) -> tuple[BaseModelType, Any]:
raise MultipleObjectsReturned()
queryset._cache_count = 1

return await queryset._get_or_cache_row(rows[0], "_cache_first,_cache_last")
return await queryset._get_or_cache_row(
rows[0], tables_and_models, "_cache_first,_cache_last"
)


class QuerySet(BaseQuerySet):
@@ -1248,11 +1269,14 @@ async def first(self) -> Union[EdgyEmbedTarget, None]:
queryset = self
if not queryset._order_by:
queryset = queryset.order_by(*self.model_class.pkcolumns)
expression, tables_and_models = await queryset.as_select_with_tables()
check_db_connection(queryset.database)
async with queryset.database as database:
row = await database.fetch_one(await queryset.as_select(), pos=0)
row = await database.fetch_one(expression, pos=0)
if row:
return (await self._get_or_cache_row(row, extra_attr="_cache_first"))[1]
return (
await self._get_or_cache_row(row, tables_and_models, extra_attr="_cache_first")
)[1]
return None

async def last(self) -> Union[EdgyEmbedTarget, None]:
@@ -1266,11 +1290,14 @@ async def last(self) -> Union[EdgyEmbedTarget, None]:
queryset = self
if not queryset._order_by:
queryset = queryset.order_by(*self.model_class.pkcolumns)
expression, tables_and_models = await queryset.reverse().as_select_with_tables()
check_db_connection(queryset.database)
async with queryset.database as database:
row = await database.fetch_one(await queryset.reverse().as_select(), pos=0)
row = await database.fetch_one(expression, pos=0)
if row:
return (await self._get_or_cache_row(row, extra_attr="_cache_last"))[1]
return (
await self._get_or_cache_row(row, tables_and_models, extra_attr="_cache_last")
)[1]
return None

async def create(self, *args: Any, **kwargs: Any) -> EdgyEmbedTarget:
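
The docstring of _build_tables_select_from_relationship above mentions tables with more than one foreign key pointing at the same destination. A hypothetical user-side setup for that scenario (model and field names are invented; standard edgy declarations are assumed):

# Invented models, only to illustrate the "two foreign keys to one table" case
# referenced in the docstring above.
import edgy

database = edgy.Database("sqlite:///db.sqlite")
models = edgy.Registry(database=database)


class User(edgy.Model):
    name: str = edgy.CharField(max_length=100)

    class Meta:
        registry = models


class Ticket(edgy.Model):
    # Two foreign keys to the same destination table; the related field name
    # ("reporter" vs "assignee") determines which foreign key drives each join.
    reporter: User = edgy.ForeignKey(User, on_delete=edgy.CASCADE, related_name="reported_tickets")
    assignee: User = edgy.ForeignKey(User, on_delete=edgy.CASCADE, related_name="assigned_tickets")

    class Meta:
        registry = models


# Each select_related path produces its own join when the query is built.
query = Ticket.query.select_related("reporter").select_related("assignee")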
