Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new encoder for dataclass #412

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,33 @@ def test_pathlib():
assert test_str == toml.dumps(o, encoder=toml.TomlPathlibEncoder())


def test_dataclass():
if (3, 7) <= sys.version_info:
from dataclasses import dataclass, asdict

@dataclass
class TestDataClassIn():
a: int

@dataclass
class TestDataClass():
a: int
c: str
nested: TestDataClassIn
nested_arr: list

dc = TestDataClass(
a=1, c="ccc",
nested=TestDataClassIn(a=1),
nested_arr=[TestDataClassIn(a=10),
TestDataClassIn(a=100)]
)
o = {"dc_on_root": dc, "dcarr_on_root": [TestDataClassIn(a=-1), TestDataClassIn(a=-2)]}
o_expected = {"dc_on_root": asdict(dc), "dcarr_on_root": [asdict(v) for v in o["dcarr_on_root"]]}

assert o_expected == toml.loads(toml.dumps(o, encoder=toml.TomlDataclassEncoder()))


def test_comment_preserve_decoder_encoder():
test_str = """[[products]]
name = "Nail"
Expand Down
1 change: 1 addition & 0 deletions toml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@
TomlNumpyEncoder = encoder.TomlNumpyEncoder
TomlPreserveCommentEncoder = encoder.TomlPreserveCommentEncoder
TomlPathlibEncoder = encoder.TomlPathlibEncoder
TomlDataclassEncoder = encoder.TomlDataclassEncoder
38 changes: 29 additions & 9 deletions toml/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ def dump_value(self, v):
# Evaluate function (if it exists) else return v
return dump_fn(v) if dump_fn is not None else self.dump_funcs[str](v)

def _preprocess_section(self, contents):
return contents

def dump_sections(self, o, sup):
retstr = ""
if sup != "" and sup[-1] != ".":
Expand All @@ -188,16 +191,17 @@ def dump_sections(self, o, sup):
for section in o:
section = unicode(section)
qsection = section
contents = self._preprocess_section(o[section])
if not re.match(r'^[A-Za-z0-9_-]+$', section):
qsection = _dump_str(section)
if not isinstance(o[section], dict):
if not isinstance(contents, dict):
arrayoftables = False
if isinstance(o[section], list):
for a in o[section]:
if isinstance(contents, list):
for a in contents:
if isinstance(a, dict):
arrayoftables = True
if arrayoftables:
for a in o[section]:
for a in contents:
arraytabstr = "\n"
arraystr += "[[" + sup + qsection + "]]\n"
s, d = self.dump_sections(a, sup + qsection)
Expand All @@ -221,14 +225,14 @@ def dump_sections(self, o, sup):
d = newd
arraystr += arraytabstr
else:
if o[section] is not None:
if contents is not None:
retstr += (qsection + " = " +
unicode(self.dump_value(o[section])) + '\n')
elif self.preserve and isinstance(o[section], InlineTableDict):
unicode(self.dump_value(contents)) + '\n')
elif self.preserve and isinstance(contents, InlineTableDict):
retstr += (qsection + " = " +
self.dump_inline_table(o[section]))
self.dump_inline_table(contents))
else:
retdict[qsection] = o[section]
retdict[qsection] = contents
retstr += arraystr
return (retstr, retdict)

Expand Down Expand Up @@ -302,3 +306,19 @@ def dump_value(self, v):
if isinstance(v, pathlib.PurePath):
v = str(v)
return super(TomlPathlibEncoder, self).dump_value(v)

class TomlDataclassEncoder(TomlEncoder):

def _preprocess_section(self, contents):
if (3, 7) <= sys.version_info:
import dataclasses
if dataclasses.is_dataclass(contents):
contents = dataclasses.asdict(contents)
elif isinstance(contents, list):
contents, _contents = [], contents
for c in _contents:
if dataclasses.is_dataclass(c):
c = dataclasses.asdict(c)
contents.append(c)

return super(TomlDataclassEncoder, self)._preprocess_section(contents)
3 changes: 3 additions & 0 deletions toml/encoder.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,6 @@ class TomlPreserveCommentEncoder(TomlEncoder):

class TomlPathlibEncoder(TomlEncoder):
def dump_value(self, v: Any): ...

class TomlDataclassEncoder(TomlEncoder):
def dump_value(self, v: Any): ...