diff --git a/.coveragerc b/.coveragerc index 569f996..c94e220 100644 --- a/.coveragerc +++ b/.coveragerc @@ -9,7 +9,6 @@ omit = py_asciimath/parser/parse_lists.py py_asciimath/utils/log.py py_asciimath/grammar/* - py_asciimath/parser/* py_asciimath/translation/* */__init__.py source = diff --git a/.gitignore b/.gitignore index 7d1d468..f86b4ce 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,9 @@ # Exclude docs docs/* +# Exclude all generated .xml files +*.xml + # Exclude main.py main.py diff --git a/MANIFEST.in b/MANIFEST.in index 9e676c9..9cda9d2 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,6 @@ include LICENSE include README.md -recursive-include examples * +graft examples +global-exclude *.py[cod] recursive-include py_asciimath/dtd/ * recursive-include py_asciimath/translation/mathml2tex/ * \ No newline at end of file diff --git a/py_asciimath/parser/parser.py b/py_asciimath/parser/parser.py index d324d66..ac69964 100644 --- a/py_asciimath/parser/parser.py +++ b/py_asciimath/parser/parser.py @@ -11,13 +11,13 @@ class MathMLParser(object): xml_decl_pattern = re.compile( - r"(<\?xml.*?(encoding=(?:'|\")(.*?)(?:'|\"))?\?>)" + r"(\s*)(<\?xml.*?(encoding=(?:'|\")(.*?)(?:'|\"))?\?>)" ) doctype_pattern = re.compile( r"()", re.MULTILINE ) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs): # pragma: no cover super(MathMLParser, self).__init__() @classmethod @@ -38,13 +38,16 @@ def get_encoding(cls, s): Returns: str: Encoding of the XML document """ - xml_decl_match = re.match(cls.xml_decl_pattern, s) - if xml_decl_match is not None: - if xml_decl_match.span(1)[0] != 0: + xml_decl_match = list(re.finditer(cls.xml_decl_pattern, s)) + if len(xml_decl_match) > 1: + raise Exception("Multiple XML declarations found") + elif len(xml_decl_match) == 1: + xml_decl_match = xml_decl_match[0] + if xml_decl_match.span()[0] != 0: raise Exception( "XML declaration must be at the beginning of the file" ) - encoding = xml_decl_match.group(3) + encoding = xml_decl_match.group(4) logging.info("Encoding from XML declaration: " + encoding) else: encoding = None @@ -82,7 +85,7 @@ def set_doctype(cls, s, network, dtd=None): doctype_match = list(re.finditer(cls.doctype_pattern, s)) if len(doctype_match) > 1: raise Exception("Multiple DOCTYPE declarations found") - if doctype_match != []: + elif doctype_match != []: doctype_match = doctype_match[0] if ( doctype_match.group(2) == "PUBLIC" @@ -111,21 +114,26 @@ def set_doctype(cls, s, network, dtd=None): "no need to bother your ISP" ) else: + dtd = dtd if dtd is not None else "mathml3" logging.warning( "No DTD declaration found: " - "set to local {} DTD".format( - dtd if dtd is not None else "mathml3" + "set to {} {} DTD".format( + "remote" if network else "local", dtd ) ) - xml_decl_match = re.match(cls.xml_decl_pattern, s) - if xml_decl_match is None: - start = 0 - else: - start = xml_decl_match.span(1)[1] - if dtd is not None: - doctype = cls.get_doctype(dtd, network) + xml_decl_match = list(re.finditer(cls.xml_decl_pattern, s)) + if len(xml_decl_match) > 1: + raise Exception("Multiple XML declarations found") + elif len(xml_decl_match) == 1: + xml_decl_match = xml_decl_match[0] + if xml_decl_match.span()[0] != 0: + raise Exception( + "XML declaration must be at the beginning of the file" + ) + start = xml_decl_match.span()[1] else: - doctype = cls.get_doctype("mathml3", network) + start = 0 + doctype = cls.get_doctype(dtd, network) s = s[:start] + doctype + s[start:] return s @@ -211,7 +219,7 @@ def get_parser( ns_clean=True, resolve_entities=False, **kwargs - ): + ): # pragma: no cover """Create a MathML XML parser Args: @@ -226,7 +234,7 @@ def get_parser( **kwargs: Additional ~lxml.extree.XMLParser options Returns: - str: Parsed and possibly validated MathML XML + lxml.etree.XMLParser: MathML parser following the specifications """ return lxml.etree.XMLParser( dtd_validation=dtd_validation, diff --git a/py_asciimath/translator/translator.py b/py_asciimath/translator/translator.py index 236702c..96eb0c9 100644 --- a/py_asciimath/translator/translator.py +++ b/py_asciimath/translator/translator.py @@ -243,15 +243,9 @@ def _translate( dstyle = '{}' else: dstyle = "{}" - if network: # pragma: no cover - if check_connection(): - doctype = MathMLParser.get_doctype(dtd, True) - else: - network = False - doctype = MathMLParser.get_doctype(dtd, False) - logging.warning("No connection available...") - else: - doctype = MathMLParser.get_doctype(dtd, False) + if network and not check_connection(): + network = False + logging.warning("No connection available...") parsed = ( ( '' @@ -273,11 +267,11 @@ def _translate( parsed, dtd, dtd_validation, network, **kwargs ) if output == "string": - encoding = parsed.getroottree().docinfo.encoding + parsed = parsed.getroottree() + encoding = parsed.docinfo.encoding parsed = lxml.etree.tostring( parsed, pretty_print=xml_pprint, - doctype=(doctype if dtd_validation else None), xml_declaration=xml_declaration, encoding=encoding, ).decode(encoding) @@ -365,10 +359,9 @@ def __init__(self): self.transformer = lxml.etree.XSLT(transformer) def _translate(self, exp, network=False, **kwargs): - if network: - if not check_connection(): - network = False - logging.warning("No connection available...") + if network and not check_connection(): + network = False + logging.warning("No connection available...") mml_version = MathMLParser.get_doctype_version(exp) if mml_version == "1": raise NotImplementedError( diff --git a/tests/test_MathMLParser.py b/tests/test_MathMLParser.py new file mode 100644 index 0000000..94f6ca3 --- /dev/null +++ b/tests/test_MathMLParser.py @@ -0,0 +1,135 @@ +import unittest + +from py_asciimath import PROJECT_ROOT +from py_asciimath.parser.parser import MathMLParser + + +class TestMathMLParser(unittest.TestCase): + def setUp(self): + self.maxDiff = None + + def test_mathmlparser_get_doctype(self): + for network in [True, False]: + for mml in ["mathml1", "mathml2", "mathml3"]: + doctype = MathMLParser.get_doctype(mml, network) + if network: + self.assertEqual( + doctype, + ''.format( + mml[-1], mml, mml + ) + if mml != "mathml1" + else "', + ) + else: + self.assertEqual( + doctype, + ''.format(mml, mml), + ) + self.assertRaises( + NotImplementedError, MathMLParser.get_doctype, "a", False + ) + + def test_mathmlparser_get_doctype_version_ok(self): + self.assertEqual( + "3", + MathMLParser.get_doctype_version( + '' + ), + ) + self.assertEqual( + "1", + MathMLParser.get_doctype_version( + '' + ), + ) + self.assertEqual( + None, MathMLParser.get_doctype_version(""), + ) + self.assertRaises( + Exception, + MathMLParser.get_doctype_version, + '\n' + '', + ) + + def test_mathmlparser_get_encoding(self): + self.assertRaises( + Exception, + MathMLParser.get_encoding, + "" + "+", + ) + self.assertRaises( + Exception, + MathMLParser.get_encoding, + "+", + ) + self.assertEqual( + "UTF-8", + MathMLParser.get_encoding( + "+" + ), + ) + self.assertEqual( + None, MathMLParser.get_encoding(""), + ) + + def test_mathmlparser_set_doctype(self): + self.assertRaises( + Exception, + MathMLParser.set_doctype, + "" + + '' + + '', + False + ) + s = ( + "" + + '' + ) + self.assertEqual( + "" + + MathMLParser.get_doctype("mathml3", False), + MathMLParser.set_doctype(s, False), + ) + s = "" + self.assertEqual( + "" + + MathMLParser.get_doctype("mathml1", False), + MathMLParser.set_doctype(s, False, dtd="mathml1"), + ) + s = ( + '' + ) + self.assertEqual(s, MathMLParser.set_doctype(s, True)) + self.assertRaises( + Exception, + MathMLParser.set_doctype, + "" + "", + False + ) + self.assertRaises( + Exception, + MathMLParser.set_doctype, + "+", + False + ) + + +if __name__ == "__main__": + unittest.main()