From eb1faa6af379844f3ee8335fe85ffad82dafc15d Mon Sep 17 00:00:00 2001 From: bilal02arbisoft Date: Mon, 23 Sep 2024 01:40:36 +0500 Subject: [PATCH] Correct query prefixing for embedded models to prevent field concatenation --- aredis_om/model/model.py | 41 +++++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 19 deletions(-) diff --git a/aredis_om/model/model.py b/aredis_om/model/model.py index 27ebcc5..c6fa624 100644 --- a/aredis_om/model/model.py +++ b/aredis_om/model/model.py @@ -309,7 +309,7 @@ def score_field(self) -> str: class ExpressionProxy: def __init__(self, field: ModelField, parents: List[Tuple[str, "RedisModel"]]): self.field = field - self.parents = parents + self.parents = parents.copy() # Ensure a copy is stored def __eq__(self, other: Any) -> Expression: # type: ignore[override] return Expression( @@ -387,13 +387,14 @@ def __getattr__(self, item): attr = getattr(embedded_cls, item) else: attr = getattr(outer_type, item) + if isinstance(attr, self.__class__): + # Clone the parents to ensure isolation + new_parents = self.parents.copy() new_parent = (self.field.alias, outer_type) - if new_parent not in attr.parents: - attr.parents.append(new_parent) - new_parents = list(set(self.parents) - set(attr.parents)) - if new_parents: - attr.parents = new_parents + attr.parents + if new_parent not in new_parents: + new_parents.append(new_parent) + attr.parents = new_parents return attr @@ -624,18 +625,19 @@ def expand_tag_value(value): @classmethod def resolve_value( - cls, - field_name: str, - field_type: RediSearchFieldTypes, - field_info: PydanticFieldInfo, - op: Operators, - value: Any, - parents: List[Tuple[str, "RedisModel"]], + cls, + field_name: str, + field_type: RediSearchFieldTypes, + field_info: PydanticFieldInfo, + op: Operators, + value: Any, + parents: List[Tuple[str, "RedisModel"]], ) -> str: + # The 'field_name' should already include the correct prefix + result = "" if parents: prefix = "_".join([p[0] for p in parents]) field_name = f"{prefix}_{field_name}" - result = "" if field_type is RediSearchFieldTypes.TEXT: result = f"@{field_name}_fts:" if op is Operators.EQ: @@ -792,15 +794,13 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: if expression.op is Operators.ALL: if encompassing_expression_is_negated: - # TODO: Is there a use case for this, perhaps for dynamic - # scoring purposes with full-text search? raise QueryNotSupportedError( "You cannot negate a query for all results." ) return "*" if isinstance(expression.left, Expression) or isinstance( - expression.left, NegatedExpression + expression.left, NegatedExpression ): result += f"({cls.resolve_redisearch_query(expression.left)})" elif isinstance(expression.left, ModelField): @@ -827,6 +827,11 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: f"or an expression enclosed in parentheses. Docs: {ERRORS_URL}#E7" ) + if isinstance(expression.left, ModelField) and expression.parents: + # Build field_name using the specific parents for this expression + prefix = "_".join([p[0] for p in expression.parents]) + field_name = f"{prefix}_{field_name}" + right = expression.right if isinstance(right, Expression) or isinstance(right, NegatedExpression): @@ -842,8 +847,6 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str: if isinstance(right, NegatedExpression): result += "-" - # We're handling the RediSearch operator in this call ("-"), so resolve the - # inner expression instead of the NegatedExpression. right = right.expression result += f"({cls.resolve_redisearch_query(right)})"