Skip to content

Commit

Permalink
Add @overload to contrib.regular_languages.compiler.Variables.get
Browse files Browse the repository at this point in the history
- Improves type hints for the `Variables` class.
- Fixes ruff pipeline check.
- Fixes several mypy issues.
  • Loading branch information
tchalupnik authored Sep 25, 2024
1 parent 6695411 commit 75615b1
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
pip list
- name: Ruff
run: |
ruff .
ruff check .
ruff format --check .
typos .
- name: Tests
Expand Down
17 changes: 12 additions & 5 deletions src/prompt_toolkit/contrib/regular_languages/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from __future__ import annotations

import re
from typing import Callable, Dict, Iterable, Iterator, Pattern
from typing import Callable, Dict, Iterable, Iterator, Pattern, TypeVar, overload
from typing import Match as RegexMatch

from .regex_parser import (
Expand All @@ -57,9 +57,7 @@
tokenize_regex,
)

__all__ = [
"compile",
]
__all__ = ["compile", "Match", "Variables"]


# Name of the named group in the regex, matching trailing input.
Expand Down Expand Up @@ -491,6 +489,9 @@ def end_nodes(self) -> Iterable[MatchVariable]:
yield MatchVariable(varname, value, (reg[0], reg[1]))


_T = TypeVar("_T")


class Variables:
def __init__(self, tuples: list[tuple[str, str, tuple[int, int]]]) -> None:
#: List of (varname, value, slice) tuples.
Expand All @@ -502,7 +503,13 @@ def __repr__(self) -> str:
", ".join(f"{k}={v!r}" for k, v, _ in self._tuples),
)

def get(self, key: str, default: str | None = None) -> str | None:
@overload
def get(self, key: str) -> str | None: ...

@overload
def get(self, key: str, default: str | _T) -> str | _T: ...

def get(self, key: str, default: str | _T | None = None) -> str | _T | None:
items = self.getall(key)
return items[0] if items else default

Expand Down
22 changes: 13 additions & 9 deletions src/prompt_toolkit/output/defaults.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TextIO, cast
from typing import TYPE_CHECKING, TextIO, cast

from prompt_toolkit.utils import (
get_bell_environment_variable,
Expand All @@ -13,13 +13,17 @@
from .color_depth import ColorDepth
from .plain_text import PlainTextOutput

if TYPE_CHECKING:
from prompt_toolkit.patch_stdout import StdoutProxy


__all__ = [
"create_output",
]


def create_output(
stdout: TextIO | None = None, always_prefer_tty: bool = False
stdout: TextIO | StdoutProxy | None = None, always_prefer_tty: bool = False
) -> Output:
"""
Return an :class:`~prompt_toolkit.output.Output` instance for the command
Expand Down Expand Up @@ -54,13 +58,6 @@ def create_output(
stdout = io
break

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

# If the patch_stdout context manager has been used, then sys.stdout is
# replaced by this proxy. For prompt_toolkit applications, we want to use
# the real stdout.
Expand All @@ -69,6 +66,13 @@ def create_output(
while isinstance(stdout, StdoutProxy):
stdout = stdout.original_stdout

# If the output is still `None`, use a DummyOutput.
# This happens for instance on Windows, when running the application under
# `pythonw.exe`. In that case, there won't be a terminal Window, and
# stdin/stdout/stderr are `None`.
if stdout is None:
return DummyOutput()

if sys.platform == "win32":
from .conemu import ConEmuOutput
from .win32 import Win32Output
Expand Down
2 changes: 1 addition & 1 deletion src/prompt_toolkit/patch_stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def flush(self) -> None:
self._flush()

@property
def original_stdout(self) -> TextIO:
def original_stdout(self) -> TextIO | None:
return self._output.stdout or sys.__stdout__

# Attributes for compatibility with sys.__stdout__:
Expand Down

0 comments on commit 75615b1

Please sign in to comment.