diff --git a/src/mdformat/renderer/_context.py b/src/mdformat/renderer/_context.py index 59ce666..a4f623e 100644 --- a/src/mdformat/renderer/_context.py +++ b/src/mdformat/renderer/_context.py @@ -151,8 +151,8 @@ def fence(node: RenderTreeNode, context: RenderContext) -> str: fence_char = "~" if "`" in info_str else "`" # Format the code block using enabled codeformatter funcs - if lang in context.options.get("codeformatters", {}): - fmt_func = context.options["codeformatters"][lang] + fmt_func = context.options.get("codeformatters", {}).get(lang) + if fmt_func: try: code_block = fmt_func(code_block, info_str) except Exception: @@ -167,6 +167,9 @@ def fence(node: RenderTreeNode, context: RenderContext) -> str: if filename: warn_msg += f". Filename: {filename}" LOGGER.warning(warn_msg) + else: + if code_block and code_block[-1] != "\n": + code_block += "\n" # The code block must not include as long or longer sequence of `fence_char`s # as the fence string itself diff --git a/tests/test_plugins.py b/tests/test_plugins.py index 27c4e81..b7c0f8e 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -12,12 +12,11 @@ from mdformat.renderer import MDRenderer, RenderContext, RenderTreeNode -def example_formatter(code, info): - return "dummy\n" - - def test_code_formatter(monkeypatch): - monkeypatch.setitem(CODEFORMATTERS, "lang", example_formatter) + def fmt_func(code, info): + return "dummy\n" + + monkeypatch.setitem(CODEFORMATTERS, "lang", fmt_func) text = mdformat.text( dedent( """\ @@ -37,6 +36,82 @@ def test_code_formatter(monkeypatch): ) +def test_code_formatter__empty_str(monkeypatch): + def fmt_func(code, info): + return "" + + monkeypatch.setitem(CODEFORMATTERS, "lang", fmt_func) + text = mdformat.text( + dedent( + """\ + ~~~lang + aag + gw + ~~~ + """ + ), + codeformatters={"lang"}, + ) + assert text == dedent( + """\ + ```lang + ``` + """ + ) + + +def test_code_formatter__no_end_newline(monkeypatch): + def fmt_func(code, info): + return "dummy\ndum" + + monkeypatch.setitem(CODEFORMATTERS, "lang", fmt_func) + text = mdformat.text( + dedent( + """\ + ```lang + ``` + """ + ), + codeformatters={"lang"}, + ) + assert text == dedent( + """\ + ```lang + dummy + dum + ``` + """ + ) + + +def test_code_formatter__interface(monkeypatch): + def fmt_func(code, info): + return info + code * 2 + + monkeypatch.setitem(CODEFORMATTERS, "lang", fmt_func) + text = mdformat.text( + dedent( + """\ + ``` lang long + multi + mul + ``` + """ + ), + codeformatters={"lang"}, + ) + assert text == dedent( + """\ + ```lang long + lang longmulti + mul + multi + mul + ``` + """ + ) + + class TextEditorPlugin: """A plugin that makes all text the same."""