diff --git a/faststream/asyncapi/generate.py b/faststream/asyncapi/generate.py index ed9b45978e..aeb2daaeed 100644 --- a/faststream/asyncapi/generate.py +++ b/faststream/asyncapi/generate.py @@ -57,7 +57,6 @@ def get_app_schema(app: Union["FastStream", "StreamRouter[Any]"]) -> Schema: payloads, messages, ) - schema = Schema( info=Info( title=app.title, @@ -146,9 +145,13 @@ def _resolve_msg_payloads( payloads: Dict[str, Any], messages: Dict[str, Any], ) -> Reference: - one_of_list: List[Reference] = [] + """Replace message payload by reference and normalize payloads. + Payloads and messages are editable dicts to store schemas for reference in AsyncAPI. + """ + one_of_list: List[Reference] = [] m.payload = _move_pydantic_refs(m.payload, DEF_KEY) + if DEF_KEY in m.payload: payloads.update(m.payload.pop(DEF_KEY)) @@ -186,6 +189,7 @@ def _move_pydantic_refs( original: Any, key: str, ) -> Any: + """Remove pydantic references and replacem them by real schemas.""" if not isinstance(original, Dict): return original diff --git a/faststream/broker/subscriber/usecase.py b/faststream/broker/subscriber/usecase.py index ffa2809cac..661524e8b1 100644 --- a/faststream/broker/subscriber/usecase.py +++ b/faststream/broker/subscriber/usecase.py @@ -407,7 +407,9 @@ def get_log_context( @property def call_name(self) -> str: """Returns the name of the handler call.""" - # TODO: default call_name + if not self.calls: + return "Subscriber" + return to_camelcase(self.calls[0].call_name) def get_description(self) -> Optional[str]: @@ -433,4 +435,14 @@ def get_payloads(self) -> List[Tuple["AnyDict", str]]: payloads.append((body, to_camelcase(h.call_name))) + if not self.calls: + payloads.append( + ( + { + "title": f"{self.title_ or self.call_name}:Message:Payload", + }, + to_camelcase(self.call_name), + ) + ) + return payloads diff --git a/faststream/confluent/subscriber/asyncapi.py b/faststream/confluent/subscriber/asyncapi.py index 7ec3ffb965..ca56bec6d6 100644 --- a/faststream/confluent/subscriber/asyncapi.py +++ b/faststream/confluent/subscriber/asyncapi.py @@ -34,7 +34,6 @@ def get_schema(self) -> Dict[str, Channel]: channels = {} payloads = self.get_payloads() - for t in self.topics: handler_name = self.title_ or f"{t}:{self.call_name}" diff --git a/tests/asyncapi/base/naming.py b/tests/asyncapi/base/naming.py index 798a24b564..0c3fc9454c 100644 --- a/tests/asyncapi/base/naming.py +++ b/tests/asyncapi/base/naming.py @@ -90,6 +90,76 @@ async def handle_user_created(msg: str): ... "custom:Message:Payload" ] + def test_subscriber_naming_default(self): + broker = self.broker_class() + + broker.subscriber("test") + + schema = get_app_schema(FastStream(broker)).to_jsonable() + + assert list(schema["channels"].keys()) == [ + IsStr(regex=r"test[\w:]*:Subscriber") + ] + + assert list(schema["components"]["messages"].keys()) == [ + IsStr(regex=r"test[\w:]*:Subscriber:Message") + ] + + for key, v in schema["components"]["schemas"].items(): + assert key == "Subscriber:Message:Payload" + assert v == {"title": key} + + def test_subscriber_naming_default_with_title(self): + broker = self.broker_class() + + broker.subscriber("test", title="custom") + + schema = get_app_schema(FastStream(broker)).to_jsonable() + + assert list(schema["channels"].keys()) == ["custom"] + + assert list(schema["components"]["messages"].keys()) == ["custom:Message"] + + assert list(schema["components"]["schemas"].keys()) == [ + "custom:Message:Payload" + ] + + assert schema["components"]["schemas"]["custom:Message:Payload"] == { + "title": "custom:Message:Payload" + } + + def test_multi_subscribers_naming_default(self): + broker = self.broker_class() + + @broker.subscriber("test") + async def handle_user_created(msg: str): ... + + broker.subscriber("test2") + broker.subscriber("test3") + + schema = get_app_schema(FastStream(broker)).to_jsonable() + + assert list(schema["channels"].keys()) == [ + IsStr(regex=r"test[\w:]*:HandleUserCreated"), + IsStr(regex=r"test2[\w:]*:Subscriber"), + IsStr(regex=r"test3[\w:]*:Subscriber"), + ] + + assert list(schema["components"]["messages"].keys()) == [ + IsStr(regex=r"test[\w:]*:HandleUserCreated:Message"), + IsStr(regex=r"test2[\w:]*:Subscriber:Message"), + IsStr(regex=r"test3[\w:]*:Subscriber:Message"), + ] + + assert list(schema["components"]["schemas"].keys()) == [ + "HandleUserCreated:Message:Payload", + "Subscriber:Message:Payload", + ] + + assert schema["components"]["schemas"]["Subscriber:Message:Payload"] == { + "title": "Subscriber:Message:Payload" + } + class FilterNaming(BaseNaming): def test_subscriber_filter_base(self):