diff --git a/atpbar/stream.py b/atpbar/stream.py deleted file mode 100644 index 8a18e45..0000000 --- a/atpbar/stream.py +++ /dev/null @@ -1,123 +0,0 @@ -import sys -import threading -from collections.abc import Callable -from enum import Enum -from io import TextIOBase -from multiprocessing import Queue -from typing import Any, TypeAlias - -from .presentation import Presentation - - -class FD(Enum): - STDOUT = 1 - STDERR = 2 - - -StreamQueue: TypeAlias = 'Queue[tuple[str, FD] | None]' - - -class StreamRedirection: - def __init__(self, queue: StreamQueue, presentation: Presentation) -> None: - self.disabled = not presentation.stdout_stderr_redirection - if self.disabled: - return - - self.queue = queue - self.presentation = presentation - - self.stdout = Stream(self.queue, FD.STDOUT) - self.stderr = Stream(self.queue, FD.STDERR) - - def start(self) -> None: - if self.disabled: - return - - self.pickup = StreamPickup( - self.queue, self.presentation.stdout_write, self.presentation.stderr_write - ) - self.pickup.start() - - self.stdout_org = sys.stdout - sys.stdout = self.stdout # type: ignore - - self.stderr_org = sys.stderr - sys.stderr = self.stderr # type: ignore - - def end(self) -> None: - if self.disabled: - return - - sys.stdout = self.stdout_org - sys.stderr = self.stderr_org - self.queue.put(None) - self.pickup.join() - - -def register_stream_queue(queue: StreamQueue) -> None: - if queue is None: - return - sys.stdout = Stream(queue, FD.STDOUT) # type: ignore - sys.stderr = Stream(queue, FD.STDERR) # type: ignore - - -class Stream(TextIOBase): - def __init__(self, queue: StreamQueue, fd: FD) -> None: - self.fd = fd - self.queue = queue - self.buffer = '' - - def write(self, s: str) -> int: - # sys.__stdout__.write(repr(s)) - # sys.__stdout__.write('\n') - - try: - endswith_n = s.endswith('\n') - except: - self.flush() - self.queue.put((s, self.fd)) - return len(s) - - if endswith_n: - self.buffer += s - self.flush() - return len(s) - - self.buffer += s - return len(s) - - def flush(self) -> None: - if not self.buffer: - return - self.queue.put((self.buffer, self.fd)) - self.buffer = '' - - -class StreamPickup(threading.Thread): - def __init__( - self, - queue: StreamQueue, - stdout_write: Callable[[str], Any], - stderr_write: Callable[[str], Any], - ) -> None: - super().__init__(daemon=True) - self.queue = queue - self.stdout_write = stdout_write - self.stderr_write = stderr_write - - def run(self) -> None: - try: - while True: - m = self.queue.get() - if m is None: - break - s, f = m - if f == FD.STDOUT: - self.stdout_write(s) - elif f == FD.STDERR: - self.stderr_write(s) - else: - raise ValueError('unknown fd: {!r}'.format(f)) - - except EOFError: - pass diff --git a/atpbar/stream/__init__.py b/atpbar/stream/__init__.py new file mode 100644 index 0000000..9ff16dc --- /dev/null +++ b/atpbar/stream/__init__.py @@ -0,0 +1,14 @@ +__all__ = [ + 'OutputStream', + 'register_stream_queue', + 'StreamPickup', + 'StreamRedirection', + 'FD', + 'Queue', + 'StreamQueue', +] + +from .output import OutputStream, register_stream_queue +from .pickup import StreamPickup +from .redirect import StreamRedirection +from .type import FD, Queue, StreamQueue diff --git a/atpbar/stream/output.py b/atpbar/stream/output.py new file mode 100644 index 0000000..484f4ba --- /dev/null +++ b/atpbar/stream/output.py @@ -0,0 +1,34 @@ +import sys +from io import TextIOBase + +from .type import FD, StreamQueue + + +def register_stream_queue(queue: StreamQueue) -> None: + if queue is None: + return + sys.stdout = OutputStream(queue, FD.STDOUT) # type: ignore + sys.stderr = OutputStream(queue, FD.STDERR) # type: ignore + + +class OutputStream(TextIOBase): + def __init__(self, queue: StreamQueue, fd: FD) -> None: + self.fd = fd + self.queue = queue + self.buffer = '' + + def write(self, s: str) -> int: + if not isinstance(s, str): + # The same error message as `sys.stdout.write()` + raise TypeError(f'write() argument must be str, not {type(s).__name__}') + + self.buffer += s + if s.endswith('\n'): + self.flush() + return len(s) + + def flush(self) -> None: + if not self.buffer: + return + self.queue.put((self.buffer, self.fd)) + self.buffer = '' diff --git a/atpbar/stream/pickup.py b/atpbar/stream/pickup.py new file mode 100644 index 0000000..eb58ec7 --- /dev/null +++ b/atpbar/stream/pickup.py @@ -0,0 +1,35 @@ +import threading +from collections.abc import Callable +from typing import Any + +from .type import FD, StreamQueue + + +class StreamPickup(threading.Thread): + def __init__( + self, + queue: StreamQueue, + stdout_write: Callable[[str], Any], + stderr_write: Callable[[str], Any], + ) -> None: + super().__init__(daemon=True) + self.queue = queue + self.stdout_write = stdout_write + self.stderr_write = stderr_write + + def run(self) -> None: + try: + while True: + m = self.queue.get() + if m is None: + break + s, f = m + if f == FD.STDOUT: + self.stdout_write(s) + elif f == FD.STDERR: + self.stderr_write(s) + else: + raise ValueError('unknown fd: {!r}'.format(f)) + + except EOFError: + pass diff --git a/atpbar/stream/redirect.py b/atpbar/stream/redirect.py new file mode 100644 index 0000000..dc4e6f0 --- /dev/null +++ b/atpbar/stream/redirect.py @@ -0,0 +1,44 @@ +import sys + +from atpbar.presentation import Presentation + +from .output import OutputStream +from .pickup import StreamPickup +from .type import FD, StreamQueue + + +class StreamRedirection: + def __init__(self, queue: StreamQueue, presentation: Presentation) -> None: + self.disabled = not presentation.stdout_stderr_redirection + if self.disabled: + return + + self.queue = queue + self.presentation = presentation + + self.stdout = OutputStream(self.queue, FD.STDOUT) + self.stderr = OutputStream(self.queue, FD.STDERR) + + def start(self) -> None: + if self.disabled: + return + + self.pickup = StreamPickup( + self.queue, self.presentation.stdout_write, self.presentation.stderr_write + ) + self.pickup.start() + + self.stdout_org = sys.stdout + sys.stdout = self.stdout # type: ignore + + self.stderr_org = sys.stderr + sys.stderr = self.stderr # type: ignore + + def end(self) -> None: + if self.disabled: + return + + sys.stdout = self.stdout_org + sys.stderr = self.stderr_org + self.queue.put(None) + self.pickup.join() diff --git a/atpbar/stream/type.py b/atpbar/stream/type.py new file mode 100644 index 0000000..1ee76fc --- /dev/null +++ b/atpbar/stream/type.py @@ -0,0 +1,11 @@ +from enum import Enum +from multiprocessing import Queue +from typing import TypeAlias + + +class FD(Enum): + STDOUT = 1 + STDERR = 2 + + +StreamQueue: TypeAlias = 'Queue[tuple[str, FD] | None]' diff --git a/tests/stream/output/__init__.py b/tests/stream/output/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/stream/test_example.py b/tests/stream/output/test_example.py similarity index 82% rename from tests/stream/test_example.py rename to tests/stream/output/test_example.py index 4b86e5b..5b4d849 100644 --- a/tests/stream/test_example.py +++ b/tests/stream/output/test_example.py @@ -3,7 +3,7 @@ from pytest import MonkeyPatch, fixture -from atpbar.stream import FD, Stream +from atpbar.stream import FD, OutputStream @fixture() @@ -12,11 +12,11 @@ def mock_queue() -> Mock: @fixture() -def obj(mock_queue: Mock) -> Stream: - return Stream(mock_queue, FD.STDOUT) +def obj(mock_queue: Mock) -> OutputStream: + return OutputStream(mock_queue, FD.STDOUT) -def test_print(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None: +def test_print(mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch) -> None: with monkeypatch.context() as m: m.setattr(sys, 'stdout', obj) @@ -53,7 +53,9 @@ def test_print(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None: assert [call(('abc', FD.STDOUT))] == mock_queue.put.call_args_list -def test_print_bytes(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None: +def test_print_bytes( + mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch +) -> None: with monkeypatch.context() as m: m.setattr(sys, 'stdout', obj) @@ -61,7 +63,7 @@ def test_print_bytes(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> assert [call(("b'abc'\n", FD.STDOUT))] == mock_queue.put.call_args_list -def test_stdout(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None: +def test_stdout(mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch) -> None: with monkeypatch.context() as m: m.setattr(sys, 'stdout', obj) diff --git a/tests/stream/test_stream.py b/tests/stream/output/test_output_stream.py similarity index 71% rename from tests/stream/test_stream.py rename to tests/stream/output/test_output_stream.py index ad2acd3..8b547fb 100644 --- a/tests/stream/test_stream.py +++ b/tests/stream/output/test_output_stream.py @@ -1,11 +1,17 @@ from unittest.mock import sentinel +import pytest from hypothesis import given, settings from hypothesis import strategies as st -from atpbar.stream import Queue, Stream, StreamQueue +from atpbar.stream import OutputStream, Queue, StreamQueue +from tests.stream.st import st_text -from .st import st_text + +def test_type_error() -> None: + stream = OutputStream(Queue(), sentinel.fd) + with pytest.raises(TypeError): + stream.write(123) # type: ignore class StatefulTest: @@ -13,18 +19,18 @@ def __init__(self, data: st.DataObject) -> None: self.draw = data.draw self.queue: StreamQueue = Queue() self.fd = sentinel.fd - self.stream = Stream(self.queue, self.fd) + self.stream = OutputStream(self.queue, self.fd) self.written = list[str]() def write(self) -> None: text = self.draw(st_text()) - self.stream.write(text) + assert self.stream.write(text) == len(text) self.written.append(text) def write_with_newline(self) -> None: text = self.draw(st_text()) text += '\n' - self.stream.write(text) + assert self.stream.write(text) == len(text) self.written.append(text) expected = ''.join(self.written) assert self.queue.get() == (expected, self.fd) @@ -41,7 +47,7 @@ def flush(self) -> None: @settings(max_examples=200) @given(data=st.data()) -def test_stream(data: st.DataObject) -> None: +def test_stateful(data: st.DataObject) -> None: test = StatefulTest(data=data) METHODS = [test.write, test.write_with_newline, test.flush] methods = data.draw(st.lists(st.sampled_from(METHODS))) diff --git a/tests/stream/test_print.py b/tests/stream/output/test_print.py similarity index 93% rename from tests/stream/test_print.py rename to tests/stream/output/test_print.py index 3c50bec..7cf5cdc 100644 --- a/tests/stream/test_print.py +++ b/tests/stream/output/test_print.py @@ -6,9 +6,8 @@ from hypothesis import strategies as st from pytest import MonkeyPatch -from atpbar.stream import FD, Stream - -from .st import st_text +from atpbar.stream import FD, OutputStream +from tests.stream.st import st_text def st_end() -> st.SearchStrategy[str | None]: @@ -27,8 +26,8 @@ def test_print( texts: list[str], end: str | None, flush: bool, fd_name: Literal['stdout', 'stderr'] ) -> None: queue = Mock() - stdout = Stream(queue, fd=FD.STDOUT) - stderr = Stream(queue, fd=FD.STDERR) + stdout = OutputStream(queue, fd=FD.STDOUT) + stderr = OutputStream(queue, fd=FD.STDERR) with MonkeyPatch.context() as m: m.setattr(sys, 'stdout', stdout)