diff --git a/ptpython/printer.py b/ptpython/printer.py new file mode 100644 index 0000000..54e3338 --- /dev/null +++ b/ptpython/printer.py @@ -0,0 +1,415 @@ +from __future__ import annotations + +import sys +import traceback +from dataclasses import dataclass +from enum import Enum +from typing import Generator, Iterable + +from prompt_toolkit.formatted_text import ( + HTML, + AnyFormattedText, + FormattedText, + OneStyleAndTextTuple, + StyleAndTextTuples, + fragment_list_width, + merge_formatted_text, + to_formatted_text, +) +from prompt_toolkit.formatted_text.utils import split_lines +from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent +from prompt_toolkit.output import Output +from prompt_toolkit.shortcuts import PromptSession, print_formatted_text +from prompt_toolkit.styles import BaseStyle, StyleTransformation +from prompt_toolkit.styles.pygments import pygments_token_to_classname +from prompt_toolkit.utils import get_cwidth +from pygments.lexers import PythonLexer, PythonTracebackLexer + +__all__ = ["OutputPrinter"] + +# Never reformat results larger than this: +MAX_REFORMAT_SIZE = 1_000_000 + + +@dataclass +class OutputPrinter: + """ + Result printer. + + Usage:: + + printer = OutputPrinter(...) + printer.display_result(...) + printer.display_exception(...) + """ + + output: Output + style: BaseStyle + title: AnyFormattedText + style_transformation: StyleTransformation + + def display_result( + self, + result: object, + *, + out_prompt: AnyFormattedText, + reformat: bool, + highlight: bool, + paginate: bool, + ) -> None: + """ + Show __repr__ (or `__pt_repr__`) for an `eval` result and print to output. + + :param reformat: Reformat result using 'black' before printing if the + result is parsable as Python code. + :param highlight: Syntax highlight the result. + :param paginate: Show paginator when the result does not fit on the + screen. + """ + out_prompt = to_formatted_text(out_prompt) + out_prompt_width = fragment_list_width(out_prompt) + + result = self._insert_out_prompt_and_split_lines( + self._format_result_output( + result, + reformat=reformat, + highlight=highlight, + line_length=self.output.get_size().columns - out_prompt_width, + ), + out_prompt=out_prompt, + ) + self._display_result(result, paginate=paginate) + + def display_exception( + self, e: BaseException, *, highlight: bool, paginate: bool + ) -> None: + """ + Render an exception. + """ + result = self._insert_out_prompt_and_split_lines( + self._format_exception_output(e, highlight=highlight), + out_prompt="", + ) + self._display_result(result, paginate=paginate) + + def _display_result( + self, + result: Iterable[list[StyleAndTextTuples]], + *, + paginate: bool, + ) -> None: + if paginate: + result = self._apply_soft_wrapping(result) + self._print_paginated_formatted_text(result) + else: + for line in result: + self._print_formatted_text(line) + + self.output.flush() + + def _print_formatted_text(self, line: StyleAndTextTuples, end: str = "\n") -> None: + print_formatted_text( + FormattedText(line), + style=self.style, + style_transformation=self.style_transformation, + include_default_pygments_style=False, + output=self.output, + end=end, + ) + + def _format_result_output( + self, + result: object, + *, + reformat: bool, + highlight: bool, + line_length: int, + ) -> Generator[OneStyleAndTextTuple, None, None]: + """ + Format __repr__ for an `eval` result. + + Note: this can raise `KeyboardInterrupt` if either calling `__repr__`, + `__pt_repr__` or formatting the output with "Black" takes to long + and the user presses Control-C. + """ + # If __pt_repr__ is present, take this. This can return prompt_toolkit + # formatted text. + try: + if hasattr(result, "__pt_repr__"): + formatted_result_repr = to_formatted_text( + getattr(result, "__pt_repr__")() + ) + yield from formatted_result_repr + return + except KeyboardInterrupt: + raise # Don't catch here. + except: + # For bad code, `__getattr__` can raise something that's not an + # `AttributeError`. This happens already when calling `hasattr()`. + pass + + # Call `__repr__` of given object first, to turn it in a string. + try: + result_repr = repr(result) + except KeyboardInterrupt: + raise # Don't catch here. + except BaseException as e: + # Calling repr failed. + self.display_exception(e) + return + + # Determine whether it's valid Python code. If not, + # reformatting/highlighting won't be applied. + if len(result_repr) < MAX_REFORMAT_SIZE: + try: + compile(result_repr, "", "eval") + except SyntaxError: + valid_python = False + else: + valid_python = True + else: + valid_python = False + + if valid_python and reformat: + # Inline import. Slightly speed up start-up time if black is + # not used. + try: + import black + + if not hasattr(black, "Mode"): + raise ImportError + except ImportError: + pass # no Black package in your installation + else: + result_repr = black.format_str( + result_repr, + mode=black.Mode(line_length=line_length), + ) + + if valid_python and highlight: + yield from _lex_python_result(result_repr) + else: + yield ("", result_repr) + + def _insert_out_prompt_and_split_lines( + self, result: Iterable[OneStyleAndTextTuple], out_prompt: AnyFormattedText + ) -> Iterable[StyleAndTextTuples]: + r""" + Split styled result in lines (based on the \n characters in the result) + an insert output prompt on whitespace in front of each line. (This does + not yet do the soft wrapping.) + + Yield lines as a result. + """ + out_prompt = to_formatted_text(out_prompt) + out_prompt_width = fragment_list_width(out_prompt) + prefix = ("", " " * out_prompt_width) + + for i, line in enumerate(split_lines(result)): + if i == 0: + line = [*out_prompt, *line] + else: + line = [prefix, *line] + yield line + + def _apply_soft_wrapping( + self, lines: Iterable[StyleAndTextTuples] + ) -> Iterable[StyleAndTextTuples]: + """ + Apply soft wrapping to the given lines. Wrap according to the terminal + width. Insert whitespace in front of each wrapped line to align it with + the output prompt. + """ + line_length = self.output.get_size().columns + + # Iterate over hard wrapped lines. + for lineno, line in enumerate(lines): + columns_in_buffer = 0 + current_line: list[OneStyleAndTextTuple] = [] + + for style, text, *_ in line: + for c in text: + width = get_cwidth(c) + + # (Soft) wrap line if it doesn't fit. + if columns_in_buffer + width > line_length: + yield current_line + columns_in_buffer = 0 + current_line = [] + + columns_in_buffer += width + current_line.append((style, c)) + + if len(current_line) > 0: + yield current_line + + def _print_paginated_formatted_text( + self, lines: Iterable[StyleAndTextTuples] + ) -> None: + """ + Print formatted text, using --MORE-- style pagination. + (Avoid filling up the terminal's scrollback buffer.) + """ + pager_prompt = create_pager_prompt(self.style, self.title) + + abort = False + print_all = False + + # Max number of lines allowed in the buffer before painting. + size = self.output.get_size() + max_rows = size.rows - 1 + + # Page buffer. + page: StyleAndTextTuples = [] + + def show_pager() -> None: + nonlocal abort, max_rows, print_all + + # Run pager prompt in another thread. + # Same as for the input. This prevents issues with nested event + # loops. + pager_result = pager_prompt.prompt(in_thread=True) + + if pager_result == PagerResult.ABORT: + print("...") + abort = True + + elif pager_result == PagerResult.NEXT_LINE: + max_rows = 1 + + elif pager_result == PagerResult.NEXT_PAGE: + max_rows = size.rows - 1 + + elif pager_result == PagerResult.PRINT_ALL: + print_all = True + + # Loop over lines. Show --MORE-- prompt when page is filled. + rows = 0 + + for lineno, line in enumerate(lines): + page.extend(line) + page.append(("", "\n")) + rows += 1 + + if rows >= max_rows: + self._print_formatted_text(page, end="") + page = [] + rows = 0 + + if not print_all: + show_pager() + if abort: + return + + self._print_formatted_text(page) + + def _format_exception_output( + self, e: BaseException, highlight: bool + ) -> Generator[OneStyleAndTextTuple, None, None]: + # Instead of just calling ``traceback.format_exc``, we take the + # traceback and skip the bottom calls of this framework. + t, v, tb = sys.exc_info() + + # Required for pdb.post_mortem() to work. + sys.last_type, sys.last_value, sys.last_traceback = t, v, tb + + tblist = list(traceback.extract_tb(tb)) + + for line_nr, tb_tuple in enumerate(tblist): + if tb_tuple[0] == "": + tblist = tblist[line_nr:] + break + + tb_list = traceback.format_list(tblist) + if tb_list: + tb_list.insert(0, "Traceback (most recent call last):\n") + tb_list.extend(traceback.format_exception_only(t, v)) + + tb_str = "".join(tb_list) + + # Format exception and write to output. + # (We use the default style. Most other styles result + # in unreadable colors for the traceback.) + if highlight: + for index, tokentype, text in PythonTracebackLexer().get_tokens_unprocessed( + tb_str + ): + yield ("class:" + pygments_token_to_classname(tokentype), text) + else: + return ("", tb_str) + + +class PagerResult(Enum): + ABORT = "ABORT" + NEXT_LINE = "NEXT_LINE" + NEXT_PAGE = "NEXT_PAGE" + PRINT_ALL = "PRINT_ALL" + + +def create_pager_prompt( + style: BaseStyle, title: AnyFormattedText = "" +) -> PromptSession[PagerResult]: + """ + Create a "--MORE--" prompt for paginated output. + """ + bindings = KeyBindings() + + @bindings.add("enter") + @bindings.add("down") + def next_line(event: KeyPressEvent) -> None: + event.app.exit(result=PagerResult.NEXT_LINE) + + @bindings.add("space") + def next_page(event: KeyPressEvent) -> None: + event.app.exit(result=PagerResult.NEXT_PAGE) + + @bindings.add("a") + def print_all(event: KeyPressEvent) -> None: + event.app.exit(result=PagerResult.PRINT_ALL) + + @bindings.add("q") + @bindings.add("c-c") + @bindings.add("c-d") + @bindings.add("escape", eager=True) + def no(event: KeyPressEvent) -> None: + event.app.exit(result=PagerResult.ABORT) + + @bindings.add("") + def _(event: KeyPressEvent) -> None: + "Disallow inserting other text." + pass + + style + + session: PromptSession[PagerResult] = PromptSession( + merge_formatted_text( + [ + title, + HTML( + "" + " -- MORE -- " + "[Enter] Scroll " + "[Space] Next page " + "[a] Print all " + "[q] Quit " + ": " + ), + ] + ), + key_bindings=bindings, + erase_when_done=True, + style=style, + ) + return session + + +def _lex_python_result(result: str) -> Generator[tuple[str, str], None, None]: + "Return token list for Python string." + lexer = PythonLexer() + # Use `get_tokens_unprocessed`, so that we get exactly the same string, + # without line endings appended. `print_formatted_text` already appends a + # line ending, and otherwise we'll have two line endings. + tokens = lexer.get_tokens_unprocessed(result) + + for index, tokentype, text in tokens: + yield ("class:" + pygments_token_to_classname(tokentype), text) diff --git a/ptpython/repl.py b/ptpython/repl.py index ce92c66..2d341cb 100644 --- a/ptpython/repl.py +++ b/ptpython/repl.py @@ -17,33 +17,17 @@ import types import warnings from dis import COMPILER_FLAG_NAMES -from enum import Enum from typing import Any, Callable, ContextManager -from prompt_toolkit.formatted_text import ( - HTML, - AnyFormattedText, - FormattedText, - PygmentsTokens, - StyleAndTextTuples, - fragment_list_width, - merge_formatted_text, - to_formatted_text, -) -from prompt_toolkit.formatted_text.utils import fragment_list_to_text, split_lines -from prompt_toolkit.key_binding import KeyBindings, KeyPressEvent from prompt_toolkit.patch_stdout import patch_stdout as patch_stdout_context from prompt_toolkit.shortcuts import ( - PromptSession, clear_title, - print_formatted_text, set_title, ) -from prompt_toolkit.styles import BaseStyle -from prompt_toolkit.utils import DummyContext, get_cwidth -from pygments.lexers import PythonLexer, PythonTracebackLexer -from pygments.token import Token +from prompt_toolkit.utils import DummyContext +from pygments.lexers import PythonTracebackLexer # noqa: F401 +from .printer import OutputPrinter from .python_input import PythonInput PyCF_ALLOW_TOP_LEVEL_AWAIT: int @@ -108,7 +92,15 @@ def run_and_show_expression(self, expression: str) -> None: else: # Print. if result is not None: - self.show_result(result) + self._get_output_printer().display_result( + result=result, + out_prompt=self.get_output_prompt(), + reformat=self.enable_output_formatting, + highlight=self.enable_syntax_highlighting, + paginate=self.enable_pager, + ) + if self.insert_blank_line_after_output: + self.app.output.write("\n") # Loop. self.current_statement_index += 1 @@ -123,6 +115,14 @@ def run_and_show_expression(self, expression: str) -> None: # any case.) self._handle_keyboard_interrupt(e) + def _get_output_printer(self) -> OutputPrinter: + return OutputPrinter( + output=self.app.output, + style=self._current_style, + style_transformation=self.style_transformation, + title=self.title, + ) + def run(self) -> None: """ Run the REPL loop. @@ -318,264 +318,12 @@ def _compile_with_flags(self, code: str, mode: str): dont_inherit=True, ) - def _format_result_output(self, result: object) -> StyleAndTextTuples: - """ - Format __repr__ for an `eval` result. - - Note: this can raise `KeyboardInterrupt` if either calling `__repr__`, - `__pt_repr__` or formatting the output with "Black" takes to long - and the user presses Control-C. - """ - out_prompt = to_formatted_text(self.get_output_prompt()) - - # If the repr is valid Python code, use the Pygments lexer. - try: - result_repr = repr(result) - except KeyboardInterrupt: - raise # Don't catch here. - except BaseException as e: - # Calling repr failed. - self._handle_exception(e) - return [] - - try: - compile(result_repr, "", "eval") - except SyntaxError: - formatted_result_repr = to_formatted_text(result_repr) - else: - # Syntactically correct. Format with black and syntax highlight. - if self.enable_output_formatting: - # Inline import. Slightly speed up start-up time if black is - # not used. - try: - import black - - if not hasattr(black, "Mode"): - raise ImportError - except ImportError: - pass # no Black package in your installation - else: - result_repr = black.format_str( - result_repr, - mode=black.Mode(line_length=self.app.output.get_size().columns), - ) - - formatted_result_repr = to_formatted_text( - PygmentsTokens(list(_lex_python_result(result_repr))) - ) - - # If __pt_repr__ is present, take this. This can return prompt_toolkit - # formatted text. - try: - if hasattr(result, "__pt_repr__"): - formatted_result_repr = to_formatted_text( - getattr(result, "__pt_repr__")() - ) - if isinstance(formatted_result_repr, list): - formatted_result_repr = FormattedText(formatted_result_repr) - except KeyboardInterrupt: - raise # Don't catch here. - except: - # For bad code, `__getattr__` can raise something that's not an - # `AttributeError`. This happens already when calling `hasattr()`. - pass - - # Align every line to the prompt. - line_sep = "\n" + " " * fragment_list_width(out_prompt) - indented_repr: StyleAndTextTuples = [] - - lines = list(split_lines(formatted_result_repr)) - - for i, fragment in enumerate(lines): - indented_repr.extend(fragment) - - # Add indentation separator between lines, not after the last line. - if i != len(lines) - 1: - indented_repr.append(("", line_sep)) - - # Write output tokens. - if self.enable_syntax_highlighting: - formatted_output = merge_formatted_text([out_prompt, indented_repr]) - else: - formatted_output = FormattedText( - out_prompt + [("", fragment_list_to_text(formatted_result_repr))] - ) - - return to_formatted_text(formatted_output) - - def show_result(self, result: object) -> None: - """ - Show __repr__ for an `eval` result and print to output. - """ - formatted_text_output = self._format_result_output(result) - - if self.enable_pager: - self.print_paginated_formatted_text(formatted_text_output) - else: - self.print_formatted_text(formatted_text_output) - - self.app.output.flush() - - if self.insert_blank_line_after_output: - self.app.output.write("\n") - - def print_formatted_text( - self, formatted_text: StyleAndTextTuples, end: str = "\n" - ) -> None: - print_formatted_text( - FormattedText(formatted_text), - style=self._current_style, - style_transformation=self.style_transformation, - include_default_pygments_style=False, - output=self.app.output, - end=end, - ) - - def print_paginated_formatted_text( - self, - formatted_text: StyleAndTextTuples, - end: str = "\n", - ) -> None: - """ - Print formatted text, using --MORE-- style pagination. - (Avoid filling up the terminal's scrollback buffer.) - """ - pager_prompt = self.create_pager_prompt() - size = self.app.output.get_size() - - abort = False - print_all = False - - # Max number of lines allowed in the buffer before painting. - max_rows = size.rows - 1 - - # Page buffer. - rows_in_buffer = 0 - columns_in_buffer = 0 - page: StyleAndTextTuples = [] - - def flush_page() -> None: - nonlocal page, columns_in_buffer, rows_in_buffer - self.print_formatted_text(page, end="") - page = [] - columns_in_buffer = 0 - rows_in_buffer = 0 - - def show_pager() -> None: - nonlocal abort, max_rows, print_all - - # Run pager prompt in another thread. - # Same as for the input. This prevents issues with nested event - # loops. - pager_result = pager_prompt.prompt(in_thread=True) - - if pager_result == PagerResult.ABORT: - print("...") - abort = True - - elif pager_result == PagerResult.NEXT_LINE: - max_rows = 1 - - elif pager_result == PagerResult.NEXT_PAGE: - max_rows = size.rows - 1 - - elif pager_result == PagerResult.PRINT_ALL: - print_all = True - - # Loop over lines. Show --MORE-- prompt when page is filled. - - formatted_text = formatted_text + [("", end)] - lines = list(split_lines(formatted_text)) - - for lineno, line in enumerate(lines): - for style, text, *_ in line: - for c in text: - width = get_cwidth(c) - - # (Soft) wrap line if it doesn't fit. - if columns_in_buffer + width > size.columns: - # Show pager first if we get too many lines after - # wrapping. - if rows_in_buffer + 1 >= max_rows and not print_all: - page.append(("", "\n")) - flush_page() - show_pager() - if abort: - return - - rows_in_buffer += 1 - columns_in_buffer = 0 - - columns_in_buffer += width - page.append((style, c)) - - if rows_in_buffer + 1 >= max_rows and not print_all: - page.append(("", "\n")) - flush_page() - show_pager() - if abort: - return - else: - # Add line ending between lines (if `end="\n"` was given, one - # more empty line is added in `split_lines` automatically to - # take care of the final line ending). - if lineno != len(lines) - 1: - page.append(("", "\n")) - rows_in_buffer += 1 - columns_in_buffer = 0 - - flush_page() - - def create_pager_prompt(self) -> PromptSession[PagerResult]: - """ - Create pager --MORE-- prompt. - """ - return create_pager_prompt(self._current_style, self.title) - - def _format_exception_output(self, e: BaseException) -> PygmentsTokens: - # Instead of just calling ``traceback.format_exc``, we take the - # traceback and skip the bottom calls of this framework. - t, v, tb = sys.exc_info() - - # Required for pdb.post_mortem() to work. - sys.last_type, sys.last_value, sys.last_traceback = t, v, tb - - tblist = list(traceback.extract_tb(tb)) - - for line_nr, tb_tuple in enumerate(tblist): - if tb_tuple[0] == "": - tblist = tblist[line_nr:] - break - - tb_list = traceback.format_list(tblist) - if tb_list: - tb_list.insert(0, "Traceback (most recent call last):\n") - tb_list.extend(traceback.format_exception_only(t, v)) - - tb_str = "".join(tb_list) - - # Format exception and write to output. - # (We use the default style. Most other styles result - # in unreadable colors for the traceback.) - if self.enable_syntax_highlighting: - tokens = list(_lex_python_traceback(tb_str)) - else: - tokens = [(Token, tb_str)] - return PygmentsTokens(tokens) - def _handle_exception(self, e: BaseException) -> None: - output = self.app.output - - tokens = self._format_exception_output(e) - - print_formatted_text( - tokens, - style=self._current_style, - style_transformation=self.style_transformation, - include_default_pygments_style=False, - output=output, + self._get_output_printer().display_exception( + e, + highlight=self.enable_syntax_highlighting, + paginate=self.enable_pager, ) - output.flush() def _handle_keyboard_interrupt(self, e: KeyboardInterrupt) -> None: output = self.app.output @@ -603,22 +351,6 @@ def _remove_from_namespace(self) -> None: del globals["get_ptpython"] -def _lex_python_traceback(tb): - "Return token list for traceback string." - lexer = PythonTracebackLexer() - return lexer.get_tokens(tb) - - -def _lex_python_result(tb): - "Return token list for Python string." - lexer = PythonLexer() - # Use `get_tokens_unprocessed`, so that we get exactly the same string, - # without line endings appended. `print_formatted_text` already appends a - # line ending, and otherwise we'll have two line endings. - tokens = lexer.get_tokens_unprocessed(tb) - return [(tokentype, value) for index, tokentype, value in tokens] - - def enable_deprecation_warnings() -> None: """ Show deprecation warnings, when they are triggered directly by actions in @@ -746,67 +478,3 @@ async def coroutine() -> None: else: with patch_context: repl.run() - - -class PagerResult(Enum): - ABORT = "ABORT" - NEXT_LINE = "NEXT_LINE" - NEXT_PAGE = "NEXT_PAGE" - PRINT_ALL = "PRINT_ALL" - - -def create_pager_prompt( - style: BaseStyle, title: AnyFormattedText = "" -) -> PromptSession[PagerResult]: - """ - Create a "continue" prompt for paginated output. - """ - bindings = KeyBindings() - - @bindings.add("enter") - @bindings.add("down") - def next_line(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.NEXT_LINE) - - @bindings.add("space") - def next_page(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.NEXT_PAGE) - - @bindings.add("a") - def print_all(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.PRINT_ALL) - - @bindings.add("q") - @bindings.add("c-c") - @bindings.add("c-d") - @bindings.add("escape", eager=True) - def no(event: KeyPressEvent) -> None: - event.app.exit(result=PagerResult.ABORT) - - @bindings.add("") - def _(event: KeyPressEvent) -> None: - "Disallow inserting other text." - pass - - style - - session: PromptSession[PagerResult] = PromptSession( - merge_formatted_text( - [ - title, - HTML( - "" - " -- MORE -- " - "[Enter] Scroll " - "[Space] Next page " - "[a] Print all " - "[q] Quit " - ": " - ), - ] - ), - key_bindings=bindings, - erase_when_done=True, - style=style, - ) - return session