Skip to content

Commit

Permalink
fix: correct dependency injection of custom context fields implementi…
Browse files Browse the repository at this point in the history
…ng partial __eq__/__ne__ (#1809)

* fix: correct dependency injection of custom context fields where the custom field value implements __eq__/__ne__ returning NotImplemented

* chore: fix precommit

* chore: fix secrets

* chore: fix secrets

* chore: refactored unittest for comparison of EMPTY with custom objects

---------

Co-authored-by: ahumbert <[email protected]>
Co-authored-by: Nikita Pastukhov <[email protected]>
  • Loading branch information
3 people authored Sep 25, 2024
1 parent bd954cb commit af92337
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 7 deletions.
5 changes: 1 addition & 4 deletions faststream/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,7 @@ def __repr__(self) -> str:
return "EMPTY"

def __eq__(self, other: object) -> bool:
if not isinstance(other, _EmptyPlaceholder):
return NotImplemented

return True
return isinstance(other, _EmptyPlaceholder)


EMPTY: Any = _EmptyPlaceholder()
6 changes: 3 additions & 3 deletions faststream/utils/context/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,13 @@ def use(self, /, **kwargs: Any) -> AnyDict:
"""
name = f"{self.prefix}{self.name or self.param_name}"

if (
if EMPTY != ( # noqa: SIM300
v := resolve_context_by_name(
name=name,
default=self.default,
initial=self.initial,
)
) != EMPTY:
):
kwargs[self.param_name] = v

else:
Expand All @@ -86,7 +86,7 @@ def resolve_context_by_name(
value = context.resolve(name)

except (KeyError, AttributeError):
if default != EMPTY:
if EMPTY != default: # noqa: SIM300
value = default

elif initial is not None:
Expand Down
35 changes: 35 additions & 0 deletions tests/utils/context/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,38 @@ def use(

assert use(1) == [1]
assert use(2) == [1, 2]


@pytest.mark.asyncio
async def test_context_with_custom_object_implementing_comparison(context: ContextRepo):
class User:
def __init__(self, user_id: int):
self.user_id = user_id

def __eq__(self, other):
if not isinstance(other, User):
return NotImplemented
return self.user_id == other.user_id

def __ne__(self, other):
return not self.__eq__(other)

user2 = User(user_id=2)
user3 = User(user_id=3)

@apply_types
async def use(
key1=Context("user1"),
key2=Context("user2", default=user2),
key3=Context("user3", default=user3),
):
return (
key1 == User(user_id=1)
and key2 == User(user_id=2)
and key3 == User(user_id=4)
)

with context.scope("user1", User(user_id=1)), context.scope(
"user3", User(user_id=4)
):
assert await use()

0 comments on commit af92337

Please sign in to comment.