diff --git a/docs/docs/en/getting-started/asgi.md b/docs/docs/en/getting-started/asgi.md index 9006c37e23..6ea640d3c5 100644 --- a/docs/docs/en/getting-started/asgi.md +++ b/docs/docs/en/getting-started/asgi.md @@ -114,6 +114,37 @@ app = AsgiFastStream( Now, your **AsyncAPI HTML** representation can be found by the `/docs` url. +### FastStream object reusage + +You may also use regular `FastStream` application object for similar result + +```python linenums="1" hl_lines="2 9" +from faststream import FastStream +from faststream.nats import NatsBroker +from faststream.asgi import make_ping_asgi, AsgiResponse + +broker = NatsBroker() + +async def liveness_ping(scope, receive, send): + return AsgiResponse(b"", status_code=200) + + +app = FastStream(broker).as_asgi( + asgi_routes=[ + ("/liveness", liveness_ping), + ("/readiness", make_ping_asgi(broker, timeout=5.0)), + ], + asyncapi_path="/docs", +) +``` + +``` tip + For app which use ASGI you may use cli command like for default FastStream app + + ```shell + faststream run main:app --host 0.0.0.0 --port 1337 --workers 4 + ``` + ## Other ASGI Compatibility Moreover, our wrappers can be used as ready-to-use endpoins for other **ASGI** frameworks. This can be very helpful When you are running **FastStream** in the same runtime as any other **ASGI** frameworks. diff --git a/docs/docs/en/getting-started/middlewares/exception.md b/docs/docs/en/getting-started/middlewares/exception.md index c9e7e3e69c..e50cbf3bdd 100644 --- a/docs/docs/en/getting-started/middlewares/exception.md +++ b/docs/docs/en/getting-started/middlewares/exception.md @@ -15,7 +15,7 @@ Sometimes, you need to register exception processors at the top level of your ap For this purpose, **FastStream** provides a special `ExceptionMiddleware`. You just need to create it, register handlers, and add it to the broker, router, or subscribers you want (as a [regular middleware](index.md){.internal-link}). ```python linenums="1" -from faststream. import ExceptionMiddleware +from faststream import ExceptionMiddleware exception_middleware = ExceptionMiddleware() diff --git a/faststream/_internal/__init__.py b/faststream/_internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/faststream/_internal/application.py b/faststream/_internal/application.py new file mode 100644 index 0000000000..dd0140db4d --- /dev/null +++ b/faststream/_internal/application.py @@ -0,0 +1,207 @@ +import logging +from abc import ABC, abstractmethod +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Optional, + Sequence, + TypeVar, + Union, +) + +import anyio +from typing_extensions import ParamSpec + +from faststream.asyncapi.proto import AsyncAPIApplication +from faststream.log.logging import logger +from faststream.utils import apply_types, context +from faststream.utils.functions import drop_response_type, fake_context, to_async + +P_HookParams = ParamSpec("P_HookParams") +T_HookReturn = TypeVar("T_HookReturn") + + +if TYPE_CHECKING: + from faststream.asyncapi.schema import ( + Contact, + ContactDict, + ExternalDocs, + ExternalDocsDict, + License, + LicenseDict, + Tag, + TagDict, + ) + from faststream.broker.core.usecase import BrokerUsecase + from faststream.types import ( + AnyDict, + AnyHttpUrl, + AsyncFunc, + Lifespan, + LoggerProto, + SettingField, + ) + + +class Application(ABC, AsyncAPIApplication): + def __init__( + self, + broker: Optional["BrokerUsecase[Any, Any]"] = None, + logger: Optional["LoggerProto"] = logger, + lifespan: Optional["Lifespan"] = None, + # AsyncAPI args, + title: str = "FastStream", + version: str = "0.1.0", + description: str = "", + terms_of_service: Optional["AnyHttpUrl"] = None, + license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, + contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, + tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, + external_docs: Optional[ + Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] + ] = None, + identifier: Optional[str] = None, + on_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_startup: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + on_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + after_shutdown: Sequence[Callable[P_HookParams, T_HookReturn]] = (), + ) -> None: + context.set_global("app", self) + + self._should_exit = anyio.Event() + self.broker = broker + self.logger = logger + self.context = context + + self._on_startup_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in on_startup + ] + self._after_startup_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in after_startup + ] + self._on_shutdown_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in on_shutdown + ] + self._after_shutdown_calling: List[AsyncFunc] = [ + apply_types(to_async(x)) for x in after_shutdown + ] + + if lifespan is not None: + self.lifespan_context = apply_types( + func=lifespan, wrap_model=drop_response_type + ) + else: + self.lifespan_context = fake_context + + # AsyncAPI information + self.title = title + self.version = version + self.description = description + self.terms_of_service = terms_of_service + self.license = license + self.contact = contact + self.identifier = identifier + self.asyncapi_tags = tags + self.external_docs = external_docs + + @abstractmethod + async def run( + self, + log_level: int, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + sleep_time: float = 0.1, + ) -> None: ... + + def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None: + """Set already existed App object broker. + + Useful then you create/init broker in `on_startup` hook. + """ + self.broker = broker + + def on_startup( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running BEFORE broker connected. + + This hook also takes an extra CLI options as a kwargs. + """ + self._on_startup_calling.append(apply_types(to_async(func))) + return func + + def on_shutdown( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running BEFORE broker disconnected.""" + self._on_shutdown_calling.append(apply_types(to_async(func))) + return func + + def after_startup( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running AFTER broker connected.""" + self._after_startup_calling.append(apply_types(to_async(func))) + return func + + def after_shutdown( + self, + func: Callable[P_HookParams, T_HookReturn], + ) -> Callable[P_HookParams, T_HookReturn]: + """Add hook running AFTER broker disconnected.""" + self._after_shutdown_calling.append(apply_types(to_async(func))) + return func + + def exit(self) -> None: + """Stop application manually.""" + self._should_exit.set() + + async def start( + self, + **run_extra_options: "SettingField", + ) -> None: + """Executes startup hooks and start broker.""" + for func in self._on_startup_calling: + await func(**run_extra_options) + + if self.broker is not None: + await self.broker.start() + + for func in self._after_startup_calling: + await func() + + async def stop(self) -> None: + """Executes shutdown hooks and stop broker.""" + for func in self._on_shutdown_calling: + await func() + + if self.broker is not None: + await self.broker.close() + + for func in self._after_shutdown_calling: + await func() + + async def _startup( + self, + log_level: int = logging.INFO, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + ) -> None: + self._log(log_level, "FastStream app starting...") + await self.start(**(run_extra_options or {})) + self._log( + log_level, "FastStream app started successfully! To exit, press CTRL+C" + ) + + async def _shutdown(self, log_level: int = logging.INFO) -> None: + self._log(log_level, "FastStream app shutting down...") + await self.stop() + self._log(log_level, "FastStream app shut down gracefully.") + + def _log(self, level: int, message: str) -> None: + if self.logger is not None: + self.logger.log(level, message) diff --git a/faststream/app.py b/faststream/app.py index a33baa3fee..3f6c2d546f 100644 --- a/faststream/app.py +++ b/faststream/app.py @@ -1,162 +1,35 @@ -import logging.config +import logging from typing import ( TYPE_CHECKING, - Any, AsyncIterator, - Callable, Dict, - List, Optional, Sequence, + Tuple, TypeVar, - Union, ) import anyio from typing_extensions import Annotated, ParamSpec, deprecated from faststream._compat import ExceptionGroup -from faststream.asyncapi.proto import AsyncAPIApplication +from faststream._internal.application import Application +from faststream.asgi.app import AsgiFastStream from faststream.cli.supervisors.utils import set_exit from faststream.exceptions import ValidationError -from faststream.log.logging import logger -from faststream.utils import apply_types, context -from faststream.utils.functions import drop_response_type, fake_context, to_async +from faststream.utils.functions import fake_context P_HookParams = ParamSpec("P_HookParams") T_HookReturn = TypeVar("T_HookReturn") if TYPE_CHECKING: - from faststream.asyncapi.schema import ( - Contact, - ContactDict, - ExternalDocs, - ExternalDocsDict, - License, - LicenseDict, - Tag, - TagDict, - ) - from faststream.broker.core.usecase import BrokerUsecase - from faststream.types import ( - AnyCallable, - AnyDict, - AnyHttpUrl, - AsyncFunc, - Lifespan, - LoggerProto, - SettingField, - ) - - -class FastStream(AsyncAPIApplication): - """A class representing a FastStream application.""" - - _on_startup_calling: List["AsyncFunc"] - _after_startup_calling: List["AsyncFunc"] - _on_shutdown_calling: List["AsyncFunc"] - _after_shutdown_calling: List["AsyncFunc"] - - def __init__( - self, - broker: Optional["BrokerUsecase[Any, Any]"] = None, - logger: Optional["LoggerProto"] = logger, - lifespan: Optional["Lifespan"] = None, - # AsyncAPI args, - title: str = "FastStream", - version: str = "0.1.0", - description: str = "", - terms_of_service: Optional["AnyHttpUrl"] = None, - license: Optional[Union["License", "LicenseDict", "AnyDict"]] = None, - contact: Optional[Union["Contact", "ContactDict", "AnyDict"]] = None, - tags: Optional[Sequence[Union["Tag", "TagDict", "AnyDict"]]] = None, - external_docs: Optional[ - Union["ExternalDocs", "ExternalDocsDict", "AnyDict"] - ] = None, - identifier: Optional[str] = None, - on_startup: Sequence["AnyCallable"] = (), - after_startup: Sequence["AnyCallable"] = (), - on_shutdown: Sequence["AnyCallable"] = (), - after_shutdown: Sequence["AnyCallable"] = (), - # all options should be copied to AsgiFastStream class - ) -> None: - context.set_global("app", self) - - self.broker = broker - self.logger = logger - self.context = context - - self._on_startup_calling = [apply_types(to_async(x)) for x in on_startup] - self._after_startup_calling = [apply_types(to_async(x)) for x in after_startup] - self._on_shutdown_calling = [apply_types(to_async(x)) for x in on_shutdown] - self._after_shutdown_calling = [ - apply_types(to_async(x)) for x in after_shutdown - ] - - self.lifespan_context = ( - apply_types( - func=lifespan, - wrap_model=drop_response_type, - ) - if lifespan is not None - else fake_context - ) - - self._should_exit = anyio.Event() - - # AsyncAPI information - self.title = title - self.version = version - self.description = description - self.terms_of_service = terms_of_service - self.license = license - self.contact = contact - self.identifier = identifier - self.asyncapi_tags = tags - self.external_docs = external_docs - - def set_broker(self, broker: "BrokerUsecase[Any, Any]") -> None: - """Set already existed App object broker. - - Useful then you create/init broker in `on_startup` hook. - """ - self.broker = broker - - def on_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker connected. + from faststream.asgi.types import ASGIApp + from faststream.types import SettingField - This hook also takes an extra CLI options as a kwargs. - """ - self._on_startup_calling.append(apply_types(to_async(func))) - return func - def on_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running BEFORE broker disconnected.""" - self._on_shutdown_calling.append(apply_types(to_async(func))) - return func - - def after_startup( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker connected.""" - self._after_startup_calling.append(apply_types(to_async(func))) - return func - - def after_shutdown( - self, - func: Callable[P_HookParams, T_HookReturn], - ) -> Callable[P_HookParams, T_HookReturn]: - """Add hook running AFTER broker disconnected.""" - self._after_shutdown_calling.append(apply_types(to_async(func))) - return func +class FastStream(Application): + """A class representing a FastStream application.""" async def run( self, @@ -188,54 +61,12 @@ async def run( for ex in e.exceptions: raise ex from None - def exit(self) -> None: - """Stop application manually.""" - self._should_exit.set() - - async def start( - self, - **run_extra_options: "SettingField", - ) -> None: - """Executes startup hooks and start broker.""" - for func in self._on_startup_calling: - await func(**run_extra_options) - - if self.broker is not None: - await self.broker.start() - - for func in self._after_startup_calling: - await func() - - async def stop(self) -> None: - """Executes shutdown hooks and stop broker.""" - for func in self._on_shutdown_calling: - await func() - - if self.broker is not None: - await self.broker.close() - - for func in self._after_shutdown_calling: - await func() - - async def _startup( + def as_asgi( self, - log_level: int = logging.INFO, - run_extra_options: Optional[Dict[str, "SettingField"]] = None, - ) -> None: - self._log(log_level, "FastStream app starting...") - await self.start(**(run_extra_options or {})) - self._log( - log_level, "FastStream app started successfully! To exit, press CTRL+C" - ) - - async def _shutdown(self, log_level: int = logging.INFO) -> None: - self._log(log_level, "FastStream app shutting down...") - await self.stop() - self._log(log_level, "FastStream app shut down gracefully.") - - def _log(self, level: int, message: str) -> None: - if self.logger is not None: - self.logger.log(level, message) + asgi_routes: Sequence[Tuple[str, "ASGIApp"]] = (), + asyncapi_path: Optional[str] = None, + ) -> AsgiFastStream: + return AsgiFastStream.from_app(self, asgi_routes, asyncapi_path) try: diff --git a/faststream/asgi/app.py b/faststream/asgi/app.py index c91a4e0517..36685f23fe 100644 --- a/faststream/asgi/app.py +++ b/faststream/asgi/app.py @@ -1,9 +1,11 @@ +import logging import traceback from contextlib import asynccontextmanager from typing import ( TYPE_CHECKING, Any, AsyncIterator, + Dict, Optional, Sequence, Tuple, @@ -12,7 +14,7 @@ import anyio -from faststream.app import FastStream +from faststream._internal.application import Application from faststream.asgi.factories import make_asyncapi_asgi from faststream.asgi.response import AsgiResponse from faststream.asgi.websocket import WebSocketClose @@ -37,10 +39,11 @@ AnyHttpUrl, Lifespan, LoggerProto, + SettingField, ) -class AsgiFastStream(FastStream): +class AsgiFastStream(Application): def __init__( self, broker: Optional["BrokerUsecase[Any, Any]"] = None, @@ -90,6 +93,36 @@ def __init__( if asyncapi_path: self.mount(asyncapi_path, make_asyncapi_asgi(self)) + @classmethod + def from_app( + cls, + app: Application, + asgi_routes: Sequence[Tuple[str, "ASGIApp"]], + asyncapi_path: Optional[str] = None, + ) -> "AsgiFastStream": + asgi_app = cls( + app.broker, + asgi_routes=asgi_routes, + asyncapi_path=asyncapi_path, + logger=app.logger, + lifespan=None, + title=app.title, + version=app.version, + description=app.description, + terms_of_service=app.terms_of_service, + license=app.license, + contact=app.contact, + tags=app.asyncapi_tags, + external_docs=app.external_docs, + identifier=app.identifier, + ) + asgi_app.lifespan_context = app.lifespan_context + asgi_app._on_startup_calling = app._on_startup_calling + asgi_app._after_startup_calling = app._after_startup_calling + asgi_app._on_shutdown_calling = app._on_shutdown_calling + asgi_app._after_shutdown_calling = app._after_shutdown_calling + return asgi_app + def mount(self, path: str, route: "ASGIApp") -> None: self.routes.append((path, route)) @@ -107,6 +140,30 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No await self.not_found(scope, receive, send) return + async def run( + self, + log_level: int = logging.INFO, + run_extra_options: Optional[Dict[str, "SettingField"]] = None, + sleep_time: float = 0.1, + ) -> None: + import uvicorn + + if not run_extra_options: + run_extra_options = {} + port = int(run_extra_options.pop("port", 8000)) # type: ignore[arg-type] + workers = int(run_extra_options.pop("workers", 1)) # type: ignore[arg-type] + host = str(run_extra_options.pop("host", "localhost")) + config = uvicorn.Config( + self, + host=host, + port=port, + log_level=log_level, + workers=workers, + **run_extra_options, + ) + server = uvicorn.Server(config) + await server.serve() + @asynccontextmanager async def start_lifespan_context(self) -> AsyncIterator[None]: async with anyio.create_task_group() as tg, self.lifespan_context(): @@ -141,7 +198,7 @@ async def lifespan(self, scope: "Scope", receive: "Receive", send: "Send") -> No await send({"type": "lifespan.shutdown.complete"}) async def not_found(self, scope: "Scope", receive: "Receive", send: "Send") -> None: - not_found_msg = "FastStream doesn't support regular HTTP protocol." + not_found_msg = "App doesn't support regular HTTP protocol." if scope["type"] == "websocket": websocket_close = WebSocketClose( diff --git a/faststream/broker/fastapi/route.py b/faststream/broker/fastapi/route.py index d8e39f96b2..03c983cae2 100644 --- a/faststream/broker/fastapi/route.py +++ b/faststream/broker/fastapi/route.py @@ -18,6 +18,7 @@ from starlette.requests import Request from faststream.broker.fastapi.get_dependant import get_fastapi_native_dependant +from faststream.broker.response import Response, ensure_response from faststream.broker.types import P_HandlerParams, T_HandlerReturn from ._compat import ( @@ -186,7 +187,7 @@ def make_fastapi_execution( response_model_exclude_none: bool, ) -> Callable[ ["StreamMessage", "NativeMessage[Any]"], - Awaitable[Any], + Awaitable[Response], ]: """Creates a FastAPI application.""" is_coroutine = asyncio.iscoroutinefunction(dependent.call) @@ -194,7 +195,7 @@ def make_fastapi_execution( async def app( request: "StreamMessage", raw_message: "NativeMessage[Any]", # to support BackgroundTasks by middleware - ) -> Any: + ) -> Response: """Consume StreamMessage and return user function result.""" async with AsyncExitStack() as stack: if FASTAPI_V106: @@ -215,14 +216,16 @@ async def app( if solved_result.errors: raise_fastapi_validation_error(solved_result.errors, request._body) # type: ignore[arg-type] - raw_reponse = await run_endpoint_function( + function_result = await run_endpoint_function( dependant=dependent, values=solved_result.values, is_coroutine=is_coroutine, ) - content = await serialize_response( - response_content=raw_reponse, + response = ensure_response(function_result) + + response.body = await serialize_response( + response_content=response.body, field=response_field, include=response_model_include, exclude=response_model_exclude, @@ -233,8 +236,8 @@ async def app( is_coroutine=is_coroutine, ) - return content + return response - return None + raise AssertionError("unreachable") return app diff --git a/faststream/cli/main.py b/faststream/cli/main.py index 900a36d810..36d321ab8d 100644 --- a/faststream/cli/main.py +++ b/faststream/cli/main.py @@ -11,6 +11,7 @@ from faststream import FastStream from faststream.__about__ import __version__ +from faststream._internal.application import Application from faststream.cli.docs.app import docs_app from faststream.cli.utils.imports import import_from_string from faststream.cli.utils.logs import LogLevels, get_log_level, set_log_level @@ -109,6 +110,7 @@ def run( app, extra = parse_cli_args(app, *ctx.args) casted_log_level = get_log_level(log_level) + module_path, app_obj = import_from_string(app) if app_dir: # pragma: no branch sys.path.insert(0, app_dir) @@ -126,8 +128,6 @@ def run( _run(*args) else: - module_path, _ = import_from_string(app) - if app_dir != ".": reload_dirs = [str(module_path), app_dir] else: @@ -142,11 +142,15 @@ def run( elif workers > 1: from faststream.cli.supervisors.multiprocess import Multiprocess - Multiprocess( - target=_run, - args=(*args, logging.DEBUG), - workers=workers, - ).run() + if isinstance(app_obj, FastStream): + Multiprocess( + target=_run, + args=(*args, logging.DEBUG), + workers=workers, + ).run() + else: + args[1]["workers"] = workers + _run(*args) else: _run(*args) @@ -165,9 +169,9 @@ def _run( if is_factory and callable(app_obj): app_obj = app_obj() - if not isinstance(app_obj, FastStream): + if not isinstance(app_obj, Application): raise typer.BadParameter( - f'Imported object "{app_obj}" must be "FastStream" type.', + f'Imported object "{app_obj}" must be "Application" type.', ) if log_level > 0: diff --git a/faststream/nats/response.py b/faststream/nats/response.py index b15bc97277..b3813131ff 100644 --- a/faststream/nats/response.py +++ b/faststream/nats/response.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Optional from typing_extensions import override @@ -13,7 +13,7 @@ def __init__( self, body: "SendableMessage", *, - headers: Optional["AnyDict"] = None, + headers: Optional[Dict[str, str]] = None, correlation_id: Optional[str] = None, stream: Optional[str] = None, ) -> None: diff --git a/faststream/nats/testing.py b/faststream/nats/testing.py index 08d43ed30f..6d34547d04 100644 --- a/faststream/nats/testing.py +++ b/faststream/nats/testing.py @@ -18,7 +18,7 @@ if TYPE_CHECKING: from faststream.nats.publisher.asyncapi import AsyncAPIPublisher from faststream.nats.subscriber.usecase import LogicSubscriber - from faststream.types import AnyDict, SendableMessage + from faststream.types import SendableMessage __all__ = ("TestNatsBroker",) @@ -187,7 +187,7 @@ def build_message( *, reply_to: str = "", correlation_id: Optional[str] = None, - headers: Optional["AnyDict"] = None, + headers: Optional[Dict[str, str]] = None, ) -> "PatchedMessage": msg, content_type = encode_message(message) return PatchedMessage( diff --git a/tests/brokers/base/fastapi.py b/tests/brokers/base/fastapi.py index f20dd61866..6c18b7e9d5 100644 --- a/tests/brokers/base/fastapi.py +++ b/tests/brokers/base/fastapi.py @@ -8,7 +8,7 @@ from fastapi.exceptions import RequestValidationError from fastapi.testclient import TestClient -from faststream import context +from faststream import Response, context from faststream.broker.core.usecase import BrokerUsecase from faststream.broker.fastapi.context import Context from faststream.broker.fastapi.router import StreamRouter @@ -226,6 +226,30 @@ async def hello(): ) assert r == "hi", r + async def test_request(self, queue: str): + """Local test due request exists in all TestClients.""" + router = self.router_class(setup_state=False) + + app = FastAPI() + + args, kwargs = self.get_subscriber_params(queue) + + @router.subscriber(*args, **kwargs) + async def hello(): + return Response("Hi!", headers={"x-header": "test"}) + + async with self.broker_test(router.broker): + with TestClient(app) as client: + assert not client.app_state.get("broker") + + r = await router.broker.request( + "hi", + queue, + timeout=0.5, + ) + assert await r.decode() == "Hi!" + assert r.headers["x-header"] == "test" + async def test_base_without_state(self, queue: str): router = self.router_class(setup_state=False) diff --git a/tests/cli/rabbit/test_app.py b/tests/cli/rabbit/test_app.py index 4765ffac9a..b0a401a450 100644 --- a/tests/cli/rabbit/test_app.py +++ b/tests/cli/rabbit/test_app.py @@ -364,5 +364,28 @@ async def lifespan(env: str): async_mock.broker_stopped.assert_called_once() +@pytest.mark.asyncio +async def test_run_asgi(async_mock: AsyncMock, app: FastStream): + asgi_routes = [("/", lambda scope, receive, send: None)] + asgi_app = app.as_asgi(asgi_routes=asgi_routes) + assert asgi_app.broker is app.broker + assert asgi_app.logger is app.logger + assert asgi_app.lifespan_context is app.lifespan_context + assert asgi_app._on_startup_calling is app._on_startup_calling + assert asgi_app._after_startup_calling is app._after_startup_calling + assert asgi_app._on_shutdown_calling is app._on_shutdown_calling + assert asgi_app._after_shutdown_calling is app._after_shutdown_calling + assert asgi_app.routes == asgi_routes + + with patch.object(app.broker, "start", async_mock.broker_run), patch.object( + app.broker, "close", async_mock.broker_stopped + ): + async with anyio.create_task_group() as tg: + tg.start_soon(app.run) + tg.start_soon(_kill, signal.SIGINT) + + async_mock.broker_run.assert_called_once() + + async def _kill(sig): os.kill(os.getpid(), sig) diff --git a/tests/cli/test_run.py b/tests/cli/test_run.py new file mode 100644 index 0000000000..c62e3519b9 --- /dev/null +++ b/tests/cli/test_run.py @@ -0,0 +1,83 @@ +import logging +from unittest.mock import AsyncMock, Mock, patch + +import pytest +from typer.testing import CliRunner + +from faststream.asgi import AsgiFastStream +from faststream.cli.main import cli as faststream_app + + +def test_run_as_asgi(runner: CliRunner): + app = AsgiFastStream() + app.run = AsyncMock() + + with patch("faststream.cli.main.import_from_string", return_value=(None, app)): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + ], + ) + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000"} + ) + assert result.exit_code == 0 + + +@pytest.mark.parametrize("workers", [1, 2, 5]) +def test_run_as_asgi_with_workers(runner: CliRunner, workers: int): + app = AsgiFastStream() + app.run = AsyncMock() + + with patch("faststream.cli.main.import_from_string", return_value=(None, app)): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--workers", + str(workers), + ], + ) + extra = {"workers": workers} if workers > 1 else {} + + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000"} | extra + ) + assert result.exit_code == 0 + + +def test_run_as_asgi_callable(runner: CliRunner): + app = AsgiFastStream() + app.run = AsyncMock() + + app_factory = Mock(return_value=app) + + with patch("faststream.cli.main.import_from_string", return_value=(None, app_factory)): + result = runner.invoke( + faststream_app, + [ + "run", + "faststream:app", + "--host", + "0.0.0.0", + "--port", + "8000", + "--factory", + ], + ) + app_factory.assert_called_once() + app.run.assert_awaited_once_with( + logging.INFO, {"host": "0.0.0.0", "port": "8000"} + ) + assert result.exit_code == 0