diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..7c955ab --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.mypy] +check_untyped_defs = true +disallow_any_generics = true +disallow_untyped_calls = true +disallow_untyped_defs = true +ignore_missing_imports = true +no_implicit_optional = true +no_implicit_reexport = true +show_column_numbers = true +show_error_codes = true +show_traceback = true +strict = true +strict_equality = true +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true +warn_unused_configs = true +warn_unused_ignores = true diff --git a/rere.py b/rere.py index 4df44ae..698f014 100755 --- a/rere.py +++ b/rere.py @@ -23,7 +23,7 @@ import sys import subprocess from difflib import unified_diff -from typing import List, BinaryIO, Tuple, Optional +from typing import BinaryIO, TypedDict def read_blob_field(f: BinaryIO, name: bytes) -> bytes: line = f.readline() @@ -42,15 +42,21 @@ def read_int_field(f: BinaryIO, name: bytes) -> int: assert line.endswith(b'\n') return int(line[len(field):-1]) -def write_int_field(f: BinaryIO, name: bytes, value: int): +def write_int_field(f: BinaryIO, name: bytes, value: int) -> None: f.write(b':i %s %d\n' % (name, value)) -def write_blob_field(f: BinaryIO, name: bytes, blob: bytes): +def write_blob_field(f: BinaryIO, name: bytes, blob: bytes) -> None: f.write(b':b %s %d\n' % (name, len(blob))) f.write(blob) f.write(b'\n') -def capture(shell: str) -> dict: +class Snapshot(TypedDict): + shell: str + returncode: int + stdout: bytes + stderr: bytes + +def capture(shell: str) -> Snapshot: print(f"CAPTURING: {shell}") process = subprocess.run(['sh', '-c', shell], capture_output = True) return { @@ -64,7 +70,7 @@ def load_list(file_path: str) -> list[str]: with open(file_path) as f: return [line.strip() for line in f] -def dump_snapshots(file_path: str, snapshots: list[dict]): +def dump_snapshots(file_path: str, snapshots: list[Snapshot]) -> None: with open(file_path, "wb") as f: write_int_field(f, b"count", len(snapshots)) for snapshot in snapshots: @@ -73,21 +79,21 @@ def dump_snapshots(file_path: str, snapshots: list[dict]): write_blob_field(f, b"stdout", snapshot['stdout']) write_blob_field(f, b"stderr", snapshot['stderr']) -def load_snapshots(file_path: str) -> list[dict]: +def load_snapshots(file_path: str) -> list[Snapshot]: snapshots = [] with open(file_path, "rb") as f: count = read_int_field(f, b"count") for _ in range(count): - shell = read_blob_field(f, b"shell") + shell: bytes = read_blob_field(f, b"shell") returncode = read_int_field(f, b"returncode") stdout = read_blob_field(f, b"stdout") stderr = read_blob_field(f, b"stderr") - snapshot = { - "shell": shell, + snapshot = Snapshot({ + "shell": shell.decode("utf-8"), "returncode": returncode, "stdout": stdout, "stderr": stderr, - } + }) snapshots.append(snapshot) return snapshots @@ -128,7 +134,7 @@ def load_snapshots(file_path: str) -> list[dict]: for (shell, snapshot) in zip(shells, snapshots): print(f"REPLAYING: {shell}") - snapshot_shell = snapshot['shell'].decode('utf-8') + snapshot_shell = snapshot['shell'] if shell != snapshot_shell: print(f"UNEXPECTED: shell command") print(f" EXPECTED: {snapshot_shell}")