Skip to content

Commit

Permalink
Merge pull request #62 from alphatwirl/dev
Browse files Browse the repository at this point in the history
Clean code
  • Loading branch information
TaiSakuma authored Jul 6, 2024
2 parents d7c7e47 + 3851a4e commit 3b72113
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 140 deletions.
123 changes: 0 additions & 123 deletions atpbar/stream.py

This file was deleted.

14 changes: 14 additions & 0 deletions atpbar/stream/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
__all__ = [
'OutputStream',
'register_stream_queue',
'StreamPickup',
'StreamRedirection',
'FD',
'Queue',
'StreamQueue',
]

from .output import OutputStream, register_stream_queue
from .pickup import StreamPickup
from .redirect import StreamRedirection
from .type import FD, Queue, StreamQueue
34 changes: 34 additions & 0 deletions atpbar/stream/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import sys
from io import TextIOBase

from .type import FD, StreamQueue


def register_stream_queue(queue: StreamQueue) -> None:
if queue is None:
return
sys.stdout = OutputStream(queue, FD.STDOUT) # type: ignore
sys.stderr = OutputStream(queue, FD.STDERR) # type: ignore


class OutputStream(TextIOBase):
def __init__(self, queue: StreamQueue, fd: FD) -> None:
self.fd = fd
self.queue = queue
self.buffer = ''

def write(self, s: str) -> int:
if not isinstance(s, str):
# The same error message as `sys.stdout.write()`
raise TypeError(f'write() argument must be str, not {type(s).__name__}')

self.buffer += s
if s.endswith('\n'):
self.flush()
return len(s)

def flush(self) -> None:
if not self.buffer:
return
self.queue.put((self.buffer, self.fd))
self.buffer = ''
35 changes: 35 additions & 0 deletions atpbar/stream/pickup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import threading
from collections.abc import Callable
from typing import Any

from .type import FD, StreamQueue


class StreamPickup(threading.Thread):
def __init__(
self,
queue: StreamQueue,
stdout_write: Callable[[str], Any],
stderr_write: Callable[[str], Any],
) -> None:
super().__init__(daemon=True)
self.queue = queue
self.stdout_write = stdout_write
self.stderr_write = stderr_write

def run(self) -> None:
try:
while True:
m = self.queue.get()
if m is None:
break
s, f = m
if f == FD.STDOUT:
self.stdout_write(s)
elif f == FD.STDERR:
self.stderr_write(s)
else:
raise ValueError('unknown fd: {!r}'.format(f))

except EOFError:
pass
44 changes: 44 additions & 0 deletions atpbar/stream/redirect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import sys

from atpbar.presentation import Presentation

from .output import OutputStream
from .pickup import StreamPickup
from .type import FD, StreamQueue


class StreamRedirection:
def __init__(self, queue: StreamQueue, presentation: Presentation) -> None:
self.disabled = not presentation.stdout_stderr_redirection
if self.disabled:
return

self.queue = queue
self.presentation = presentation

self.stdout = OutputStream(self.queue, FD.STDOUT)
self.stderr = OutputStream(self.queue, FD.STDERR)

def start(self) -> None:
if self.disabled:
return

self.pickup = StreamPickup(
self.queue, self.presentation.stdout_write, self.presentation.stderr_write
)
self.pickup.start()

self.stdout_org = sys.stdout
sys.stdout = self.stdout # type: ignore

self.stderr_org = sys.stderr
sys.stderr = self.stderr # type: ignore

def end(self) -> None:
if self.disabled:
return

sys.stdout = self.stdout_org
sys.stderr = self.stderr_org
self.queue.put(None)
self.pickup.join()
11 changes: 11 additions & 0 deletions atpbar/stream/type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from enum import Enum
from multiprocessing import Queue
from typing import TypeAlias


class FD(Enum):
STDOUT = 1
STDERR = 2


StreamQueue: TypeAlias = 'Queue[tuple[str, FD] | None]'
Empty file added tests/stream/output/__init__.py
Empty file.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from pytest import MonkeyPatch, fixture

from atpbar.stream import FD, Stream
from atpbar.stream import FD, OutputStream


@fixture()
Expand All @@ -12,11 +12,11 @@ def mock_queue() -> Mock:


@fixture()
def obj(mock_queue: Mock) -> Stream:
return Stream(mock_queue, FD.STDOUT)
def obj(mock_queue: Mock) -> OutputStream:
return OutputStream(mock_queue, FD.STDOUT)


def test_print(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None:
def test_print(mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch) -> None:
with monkeypatch.context() as m:
m.setattr(sys, 'stdout', obj)

Expand Down Expand Up @@ -53,15 +53,17 @@ def test_print(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None:
assert [call(('abc', FD.STDOUT))] == mock_queue.put.call_args_list


def test_print_bytes(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None:
def test_print_bytes(
mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch
) -> None:
with monkeypatch.context() as m:
m.setattr(sys, 'stdout', obj)

print(b'abc')
assert [call(("b'abc'\n", FD.STDOUT))] == mock_queue.put.call_args_list


def test_stdout(mock_queue: Mock, obj: Stream, monkeypatch: MonkeyPatch) -> None:
def test_stdout(mock_queue: Mock, obj: OutputStream, monkeypatch: MonkeyPatch) -> None:
with monkeypatch.context() as m:
m.setattr(sys, 'stdout', obj)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
from unittest.mock import sentinel

import pytest
from hypothesis import given, settings
from hypothesis import strategies as st

from atpbar.stream import Queue, Stream, StreamQueue
from atpbar.stream import OutputStream, Queue, StreamQueue
from tests.stream.st import st_text

from .st import st_text

def test_type_error() -> None:
stream = OutputStream(Queue(), sentinel.fd)
with pytest.raises(TypeError):
stream.write(123) # type: ignore


class StatefulTest:
def __init__(self, data: st.DataObject) -> None:
self.draw = data.draw
self.queue: StreamQueue = Queue()
self.fd = sentinel.fd
self.stream = Stream(self.queue, self.fd)
self.stream = OutputStream(self.queue, self.fd)
self.written = list[str]()

def write(self) -> None:
text = self.draw(st_text())
self.stream.write(text)
assert self.stream.write(text) == len(text)
self.written.append(text)

def write_with_newline(self) -> None:
text = self.draw(st_text())
text += '\n'
self.stream.write(text)
assert self.stream.write(text) == len(text)
self.written.append(text)
expected = ''.join(self.written)
assert self.queue.get() == (expected, self.fd)
Expand All @@ -41,7 +47,7 @@ def flush(self) -> None:

@settings(max_examples=200)
@given(data=st.data())
def test_stream(data: st.DataObject) -> None:
def test_stateful(data: st.DataObject) -> None:
test = StatefulTest(data=data)
METHODS = [test.write, test.write_with_newline, test.flush]
methods = data.draw(st.lists(st.sampled_from(METHODS)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from hypothesis import strategies as st
from pytest import MonkeyPatch

from atpbar.stream import FD, Stream

from .st import st_text
from atpbar.stream import FD, OutputStream
from tests.stream.st import st_text


def st_end() -> st.SearchStrategy[str | None]:
Expand All @@ -27,8 +26,8 @@ def test_print(
texts: list[str], end: str | None, flush: bool, fd_name: Literal['stdout', 'stderr']
) -> None:
queue = Mock()
stdout = Stream(queue, fd=FD.STDOUT)
stderr = Stream(queue, fd=FD.STDERR)
stdout = OutputStream(queue, fd=FD.STDOUT)
stderr = OutputStream(queue, fd=FD.STDERR)

with MonkeyPatch.context() as m:
m.setattr(sys, 'stdout', stdout)
Expand Down

0 comments on commit 3b72113

Please sign in to comment.