diff --git a/mdformat_myst/plugin.py b/mdformat_myst/plugin.py index 58c9622..31512e4 100644 --- a/mdformat_myst/plugin.py +++ b/mdformat_myst/plugin.py @@ -1,7 +1,9 @@ from __future__ import annotations +import argparse import re from textwrap import indent +from typing import Set from markdown_it import MarkdownIt import mdformat.plugins @@ -16,8 +18,35 @@ _TARGET_PATTERN = re.compile(r"^\s*\(([a-zA-Z0-9|@<>*./_\-+:]{1,100})\)=\s*$") _ROLE_NAME_PATTERN = re.compile(r"({[a-zA-Z0-9_\-+:]{1,36}})") +# TODO eventually support all in: +# https://myst-parser.readthedocs.io/en/latest/syntax/optional.html +SUPPORTED_EXTENSIONS = {"dollarmath"} + + +def _extension_arg_type(arg: str) -> Set[str]: + """Convert extension list to set and validate.""" + extensions = set(arg.split(",")) + unallowed_extensions = extensions.difference(SUPPORTED_EXTENSIONS) + if unallowed_extensions: + raise ValueError(f"Unsupported myst extensions: {unallowed_extensions!r}") + return extensions + + +def add_cli_options(parser: argparse.ArgumentParser) -> None: + """Add options to the mdformat CLI, to be stored in + ``mdit.options["mdformat"]``""" + parser.add_argument( + "--myst-extensions", + dest="myst_extensions", + type=_extension_arg_type, + default="dollarmath", + help="Comma-delimited list of MyST syntax extensions", + metavar="|".join(SUPPORTED_EXTENSIONS), + ) + def update_mdit(mdit: MarkdownIt) -> None: + myst_extensions = mdit.options.get("mdformat", {}).get("myst_extensions", set()) # Enable mdformat-tables plugin tables_plugin = mdformat.plugins.PARSER_EXTENSIONS["tables"] if tables_plugin not in mdit.options["parser_extension"]: @@ -38,7 +67,8 @@ def update_mdit(mdit: MarkdownIt) -> None: mdit.use(myst_block_plugin) # Enable dollarmath markdown-it extension - mdit.use(dollarmath_plugin) + if "dollarmath" in myst_extensions: + mdit.use(dollarmath_plugin) # Enable footnote markdown-it extension mdit.use(footnote_plugin) diff --git a/tests/test_mdformat_myst.py b/tests/test_mdformat_myst.py index 51fffb7..1dc8165 100644 --- a/tests/test_mdformat_myst.py +++ b/tests/test_mdformat_myst.py @@ -13,7 +13,9 @@ ) def test_fixtures__api(line, title, text, expected): """Test fixtures in tests/data/fixtures.md.""" - md_new = mdformat.text(text, extensions={"myst"}) + md_new = mdformat.text( + text, extensions={"myst"}, options={"myst_extensions": {"dollarmath"}} + ) try: assert md_new == expected except Exception: @@ -21,6 +23,14 @@ def test_fixtures__api(line, title, text, expected): raise +def test_cli_unsupported_extension(tmp_path): + file_path = tmp_path / "test_markdown.md" + file_path.write_text("a") + + with pytest.raises(SystemExit): + mdformat._cli.run([str(file_path), "--myst-extensions", "a"]) + + @pytest.mark.parametrize( "line,title,text,expected", TEST_CASES, ids=[f[1] for f in TEST_CASES] ) @@ -28,6 +38,6 @@ def test_fixtures__cli(line, title, text, expected, tmp_path): """Test fixtures in tests/data/fixtures.md.""" file_path = tmp_path / "test_markdown.md" file_path.write_text(text) - assert mdformat._cli.run([str(file_path)]) == 0 + assert mdformat._cli.run([str(file_path), "--myst-extensions", "dollarmath"]) == 0 md_new = file_path.read_text() assert md_new == expected