diff --git a/tests/test_api.py b/tests/test_api.py index 1acc26f..251331a 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -242,6 +242,33 @@ def test_pathlib(): assert test_str == toml.dumps(o, encoder=toml.TomlPathlibEncoder()) +def test_dataclass(): + if (3, 7) <= sys.version_info: + from dataclasses import dataclass, asdict + + @dataclass + class TestDataClassIn(): + a: int + + @dataclass + class TestDataClass(): + a: int + c: str + nested: TestDataClassIn + nested_arr: list + + dc = TestDataClass( + a=1, c="ccc", + nested=TestDataClassIn(a=1), + nested_arr=[TestDataClassIn(a=10), + TestDataClassIn(a=100)] + ) + o = {"dc_on_root": dc, "dcarr_on_root": [TestDataClassIn(a=-1), TestDataClassIn(a=-2)]} + o_expected = {"dc_on_root": asdict(dc), "dcarr_on_root": [asdict(v) for v in o["dcarr_on_root"]]} + + assert o_expected == toml.loads(toml.dumps(o, encoder=toml.TomlDataclassEncoder())) + + def test_comment_preserve_decoder_encoder(): test_str = """[[products]] name = "Nail" diff --git a/toml/__init__.py b/toml/__init__.py index 7719ac2..3614eac 100644 --- a/toml/__init__.py +++ b/toml/__init__.py @@ -23,3 +23,4 @@ TomlNumpyEncoder = encoder.TomlNumpyEncoder TomlPreserveCommentEncoder = encoder.TomlPreserveCommentEncoder TomlPathlibEncoder = encoder.TomlPathlibEncoder +TomlDataclassEncoder = encoder.TomlDataclassEncoder diff --git a/toml/encoder.py b/toml/encoder.py index bf17a72..b4b2e59 100644 --- a/toml/encoder.py +++ b/toml/encoder.py @@ -179,6 +179,9 @@ def dump_value(self, v): # Evaluate function (if it exists) else return v return dump_fn(v) if dump_fn is not None else self.dump_funcs[str](v) + def _preprocess_section(self, contents): + return contents + def dump_sections(self, o, sup): retstr = "" if sup != "" and sup[-1] != ".": @@ -188,16 +191,17 @@ def dump_sections(self, o, sup): for section in o: section = unicode(section) qsection = section + contents = self._preprocess_section(o[section]) if not re.match(r'^[A-Za-z0-9_-]+$', section): qsection = _dump_str(section) - if not isinstance(o[section], dict): + if not isinstance(contents, dict): arrayoftables = False - if isinstance(o[section], list): - for a in o[section]: + if isinstance(contents, list): + for a in contents: if isinstance(a, dict): arrayoftables = True if arrayoftables: - for a in o[section]: + for a in contents: arraytabstr = "\n" arraystr += "[[" + sup + qsection + "]]\n" s, d = self.dump_sections(a, sup + qsection) @@ -221,14 +225,14 @@ def dump_sections(self, o, sup): d = newd arraystr += arraytabstr else: - if o[section] is not None: + if contents is not None: retstr += (qsection + " = " + - unicode(self.dump_value(o[section])) + '\n') - elif self.preserve and isinstance(o[section], InlineTableDict): + unicode(self.dump_value(contents)) + '\n') + elif self.preserve and isinstance(contents, InlineTableDict): retstr += (qsection + " = " + - self.dump_inline_table(o[section])) + self.dump_inline_table(contents)) else: - retdict[qsection] = o[section] + retdict[qsection] = contents retstr += arraystr return (retstr, retdict) @@ -302,3 +306,19 @@ def dump_value(self, v): if isinstance(v, pathlib.PurePath): v = str(v) return super(TomlPathlibEncoder, self).dump_value(v) + +class TomlDataclassEncoder(TomlEncoder): + + def _preprocess_section(self, contents): + if (3, 7) <= sys.version_info: + import dataclasses + if dataclasses.is_dataclass(contents): + contents = dataclasses.asdict(contents) + elif isinstance(contents, list): + contents, _contents = [], contents + for c in _contents: + if dataclasses.is_dataclass(c): + c = dataclasses.asdict(c) + contents.append(c) + + return super(TomlDataclassEncoder, self)._preprocess_section(contents) \ No newline at end of file diff --git a/toml/encoder.pyi b/toml/encoder.pyi index 194a358..584234c 100644 --- a/toml/encoder.pyi +++ b/toml/encoder.pyi @@ -32,3 +32,6 @@ class TomlPreserveCommentEncoder(TomlEncoder): class TomlPathlibEncoder(TomlEncoder): def dump_value(self, v: Any): ... + +class TomlDataclassEncoder(TomlEncoder): + def dump_value(self, v: Any): ...