Skip to content

Commit

Permalink
Merge branch 'main' into 0.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
KrySeyt committed Sep 12, 2024
2 parents 4842baa + 6a06a38 commit 91405d9
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 36 deletions.
15 changes: 14 additions & 1 deletion faststream/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,20 @@

import warnings

from faststream.cli.main import cli
try:
from faststream.cli.main import cli
except ImportError:
has_typer = False
else:
has_typer = True

if not has_typer:
raise ImportError(
"\n\nYou're trying to use the FastStream CLI, "
"\nbut you haven't installed the required dependencies."
"\nPlease install them using the following command: "
'\npip install "faststream[cli]"'
)

warnings.filterwarnings("default", category=ImportWarning, module="faststream")

Expand Down
8 changes: 7 additions & 1 deletion faststream/specification/asyncapi/v2_6_0/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,26 @@ def _resolve_msg_payloads(
one_of = m.payload.get("oneOf")
if isinstance(one_of, dict):
for p_title, p in one_of.items():
p_title = p_title.replace("/", ".")
payloads.update(p.pop(DEF_KEY, {}))
if p_title not in payloads:
payloads[p_title] = p
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{p_title}"}))

elif one_of is not None:
# Descriminator case
for p in one_of:
p_title = next(iter(p.values())).split("/")[-1]
p_value = next(iter(p.values()))
p_title = p_value.split("/")[-1]
p_title = p_title.replace("/", ".")
if p_title not in payloads:
payloads[p_title] = p
one_of_list.append(Reference(**{"$ref": f"#/components/schemas/{p_title}"}))

if not one_of_list:
payloads.update(m.payload.pop(DEF_KEY, {}))
p_title = m.payload.get("title", f"{channel_name}Payload")
p_title = p_title.replace("/", ".")
if p_title not in payloads:
payloads[p_title] = m.payload
m.payload = {"$ref": f"#/components/schemas/{p_title}"}
Expand All @@ -196,6 +201,7 @@ def _resolve_msg_payloads(
m.payload["oneOf"] = one_of_list

assert m.title # nosec B101
m.title = m.title.replace("/", ".")
messages[m.title] = m
return Reference(**{"$ref": f"#/components/messages/{m.title}"})

Expand Down
64 changes: 30 additions & 34 deletions tests/asyncapi/base/v2_6_0/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,36 @@ async def handle(user: Model): ...
},
}, schema["components"]

def test_with_filter(self):
class User(pydantic.BaseModel):
name: str = ""
id: int

broker = self.broker_class()

sub = broker.subscriber("test/one")

@sub(
filter=lambda m: m.content_type == "application/json",
)
async def handle(id: int): ...

@sub
async def handle_default(msg): ...

schema = get_app_schema(self.build_app(broker), version="2.6.0").to_jsonable()

name, message = next(iter(schema["components"]["messages"].items()))

assert name == IsStr(regex=r"test.one[\w:]*:Handle:Message"), name

assert len(message["payload"]["oneOf"]) == 2

payload = schema["components"]["schemas"]

assert "Handle:Message:Payload" in list(payload.keys())
assert "HandleDefault:Message:Payload" in list(payload.keys())


class ArgumentsTestcase(FastAPICompatible):
dependency_builder = staticmethod(Depends)
Expand Down Expand Up @@ -616,37 +646,3 @@ async def handle(id: int, user: Optional[str] = None, message=Context()): ...
"type": "object",
}
)

def test_with_filter(self):
# TODO: move it to FastAPICompatible with FastAPI refactore
class User(pydantic.BaseModel):
name: str = ""
id: int

broker = self.broker_class()

sub = broker.subscriber("test")

@sub(
filter=lambda m: m.content_type == "application/json",
)
async def handle(id: int): ...

@sub
async def handle_default(msg): ...

schema = get_app_schema(self.build_app(broker), version="2.6.0").to_jsonable()

assert (
len(
next(iter(schema["components"]["messages"].values()))["payload"][
"oneOf"
]
)
== 2
)

payload = schema["components"]["schemas"]

assert "Handle:Message:Payload" in list(payload.keys())
assert "HandleDefault:Message:Payload" in list(payload.keys())

0 comments on commit 91405d9

Please sign in to comment.