diff --git a/faststream/types.py b/faststream/types.py index 6404614db7..ae34858025 100644 --- a/faststream/types.py +++ b/faststream/types.py @@ -114,10 +114,7 @@ def __repr__(self) -> str: return "EMPTY" def __eq__(self, other: object) -> bool: - if not isinstance(other, _EmptyPlaceholder): - return NotImplemented - - return True + return isinstance(other, _EmptyPlaceholder) EMPTY: Any = _EmptyPlaceholder() diff --git a/faststream/utils/context/types.py b/faststream/utils/context/types.py index b176132217..5ca17d7ff3 100644 --- a/faststream/utils/context/types.py +++ b/faststream/utils/context/types.py @@ -60,13 +60,13 @@ def use(self, /, **kwargs: Any) -> AnyDict: """ name = f"{self.prefix}{self.name or self.param_name}" - if ( + if EMPTY != ( # noqa: SIM300 v := resolve_context_by_name( name=name, default=self.default, initial=self.initial, ) - ) != EMPTY: + ): kwargs[self.param_name] = v else: @@ -86,7 +86,7 @@ def resolve_context_by_name( value = context.resolve(name) except (KeyError, AttributeError): - if default != EMPTY: + if EMPTY != default: # noqa: SIM300 value = default elif initial is not None: diff --git a/tests/utils/context/test_main.py b/tests/utils/context/test_main.py index 2a7ca6d093..39e6434cec 100644 --- a/tests/utils/context/test_main.py +++ b/tests/utils/context/test_main.py @@ -174,3 +174,38 @@ def use( assert use(1) == [1] assert use(2) == [1, 2] + + +@pytest.mark.asyncio +async def test_context_with_custom_object_implementing_comparison(context: ContextRepo): + class User: + def __init__(self, user_id: int): + self.user_id = user_id + + def __eq__(self, other): + if not isinstance(other, User): + return NotImplemented + return self.user_id == other.user_id + + def __ne__(self, other): + return not self.__eq__(other) + + user2 = User(user_id=2) + user3 = User(user_id=3) + + @apply_types + async def use( + key1=Context("user1"), + key2=Context("user2", default=user2), + key3=Context("user3", default=user3), + ): + return ( + key1 == User(user_id=1) + and key2 == User(user_id=2) + and key3 == User(user_id=4) + ) + + with context.scope("user1", User(user_id=1)), context.scope( + "user3", User(user_id=4) + ): + assert await use()