From 89540d6d4ffcdc37fcf4c5addd36195c8a6935ed Mon Sep 17 00:00:00 2001 From: Tai Sakuma Date: Tue, 25 Jun 2024 10:04:01 -0400 Subject: [PATCH] Add `subscribe()` to `CacheStdout` --- nextlinegraphql/plugins/ctrl/__init__.py | 7 ++++--- nextlinegraphql/plugins/ctrl/cache.py | 14 ++++++++++++-- .../plugins/ctrl/schema/subscription.py | 16 ++++++---------- .../ctrl/schema/subscriptions/test_stdout.py | 6 +++--- 4 files changed, 25 insertions(+), 18 deletions(-) diff --git a/nextlinegraphql/plugins/ctrl/__init__.py b/nextlinegraphql/plugins/ctrl/__init__.py index 3409733..0a3019c 100644 --- a/nextlinegraphql/plugins/ctrl/__init__.py +++ b/nextlinegraphql/plugins/ctrl/__init__.py @@ -20,18 +20,19 @@ def schema(self) -> tuple[type, type | None, type | None]: @spec.hookimpl async def update_lifespan_context(self, context: MutableMapping) -> None: self._nextline = Nextline(statement) - self._stdout_cache = list[str]() context['nextline'] = self._nextline @spec.hookimpl(trylast=True) # trylast so to be the innermost context @asynccontextmanager async def lifespan(self) -> AsyncIterator[None]: '''Yield within the nextline context.''' - self._nextline.register(CacheStdout(self._stdout_cache)) + self._cache_stdout = CacheStdout(self._nextline) + self._nextline.register(self._cache_stdout) async with self._nextline: yield @spec.hookimpl def update_strawberry_context(self, context: MutableMapping) -> None: context['nextline'] = self._nextline - context['stdout_cache'] = self._stdout_cache + ctrl = {'cache_stdout': self._cache_stdout} + context['ctrl'] = ctrl diff --git a/nextlinegraphql/plugins/ctrl/cache.py b/nextlinegraphql/plugins/ctrl/cache.py index 87cf55c..46e764f 100644 --- a/nextlinegraphql/plugins/ctrl/cache.py +++ b/nextlinegraphql/plugins/ctrl/cache.py @@ -1,10 +1,20 @@ +from collections.abc import AsyncIterator + +from nextline import Nextline from nextline.events import OnWriteStdout from nextline.plugin.spec import hookimpl class CacheStdout: - def __init__(self, cache: list[str]) -> None: - self._cache = cache + def __init__(self, nextline: Nextline) -> None: + self._nextline = nextline + self._cache = list[str]() + + async def subscribe(self) -> AsyncIterator[str]: + yield ''.join(self._cache) + async for i in self._nextline.subscribe_stdout(): + assert i.text is not None + yield i.text @hookimpl async def on_initialize_run(self) -> None: diff --git a/nextlinegraphql/plugins/ctrl/schema/subscription.py b/nextlinegraphql/plugins/ctrl/schema/subscription.py index da7794e..708cc94 100644 --- a/nextlinegraphql/plugins/ctrl/schema/subscription.py +++ b/nextlinegraphql/plugins/ctrl/schema/subscription.py @@ -1,12 +1,11 @@ import asyncio from collections.abc import AsyncIterator -from typing import TYPE_CHECKING import strawberry +from nextline import Nextline from strawberry.types import Info -if TYPE_CHECKING: - from nextline import Nextline +from nextlinegraphql.plugins.ctrl.cache import CacheStdout @strawberry.type @@ -57,13 +56,10 @@ async def subscribe_prompting( yield y -async def subscribe_stdout(info: Info) -> AsyncIterator[str]: - nextline: Nextline = info.context["nextline"] - stdout_cache: list[str] = info.context["stdout_cache"] - yield ''.join(stdout_cache) - async for i in nextline.subscribe_stdout(): - assert i.text is not None - yield i.text +def subscribe_stdout(info: Info) -> AsyncIterator[str]: + cache_stdout = info.context['ctrl']['cache_stdout'] + assert isinstance(cache_stdout, CacheStdout) + return cache_stdout.subscribe() def subscribe_continuous_enabled(info: Info) -> AsyncIterator[bool]: diff --git a/tests/plugins/ctrl/schema/subscriptions/test_stdout.py b/tests/plugins/ctrl/schema/subscriptions/test_stdout.py index 8359325..3d0ba82 100644 --- a/tests/plugins/ctrl/schema/subscriptions/test_stdout.py +++ b/tests/plugins/ctrl/schema/subscriptions/test_stdout.py @@ -15,10 +15,10 @@ async def test_schema(schema: Schema) -> None: nextline = Nextline(SOURCE, trace_modules=True, trace_threads=True) - cache = list[str]() - nextline.register(CacheStdout(cache)) + cache_stdout = CacheStdout(nextline) + nextline.register(cache_stdout) started = asyncio.Event() - context = {'nextline': nextline, 'stdout_cache': cache} + context = {'nextline': nextline, 'ctrl': {'cache_stdout': cache_stdout}} async with nextline: task = asyncio.create_task(nextline.run_continue_and_wait(started=started)) await started.wait()