Skip to content

Commit

Permalink
feat: preserve tuples in JSON
Browse files Browse the repository at this point in the history
  • Loading branch information
nkemnitz committed Sep 19, 2023
1 parent 4fdfaba commit 42da668
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 10 deletions.
3 changes: 1 addition & 2 deletions tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# pylint: disable=unused-argument,redefined-outer-name
import json

import pytest
from click.testing import CliRunner

from zetta_utils import builder, cli
from zetta_utils.parsing import json


@pytest.fixture
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/parsing/test_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from io import StringIO

from zetta_utils.parsing import json


def test_str_roundtrip():
assert json.loads(json.dumps([1, "2"])) == [1, "2"]


def test_str_tuple_roundtrip():
assert json.loads(json.dumps((1, "2"))) == (1, "2")


def test_str_nested_tuple_roundtrip():
assert json.loads(json.dumps((1, "2", (3, "4")))) == (1, "2", (3, "4"))


def test_fp_roundtrip():
fp = StringIO()
json.dump([1, "2"], fp)
fp.seek(0)
assert json.load(fp) == [1, "2"]


def test_fp_tuple_roundtrip():
fp = StringIO()
json.dump((1, "2"), fp)
fp.seek(0)
assert json.load(fp) == (1, "2")


def test_fp_nested_tuple_roundtrip():
fp = StringIO()
json.dump((1, "2", (3, "4")), fp)
fp.seek(0)
assert json.load(fp) == (1, "2", (3, "4"))
3 changes: 1 addition & 2 deletions zetta_utils/builder/build.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
"""Bulding objects from nested specs."""
from __future__ import annotations

import json
from typing import Any, Callable, Final, Optional

import attrs
from typeguard import typechecked

from zetta_utils import parsing
from zetta_utils.common import ctx_managers
from zetta_utils.parsing import json

from . import constants
from .registry import get_matching_entry
Expand Down Expand Up @@ -183,7 +183,6 @@ def get_display_name(self): # pragma: no cover # pretty print

def _get_built_spec_kwargs(self, version: str) -> dict[str, Any]:
if self._built_spec_kwargs is None:

self._built_spec_kwargs = {
k: _traverse_spec(
v,
Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import pprint
from typing import Optional
Expand All @@ -7,6 +6,7 @@

import zetta_utils
from zetta_utils import log
from zetta_utils.parsing import json

logger = log.get_logger("zetta_utils")

Expand Down
2 changes: 1 addition & 1 deletion zetta_utils/cloud_management/execution_tracker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import time
from contextlib import contextmanager
Expand All @@ -14,6 +13,7 @@
from zetta_utils.layer.db_layer import DBRowDataT, build_db_layer
from zetta_utils.layer.db_layer.datastore import DatastoreBackend
from zetta_utils.log import get_logger
from zetta_utils.parsing import json

from .resource_allocation.k8s import ClusterInfo

Expand Down
1 change: 1 addition & 0 deletions zetta_utils/parsing/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from . import cue
from . import ngl_state
from . import json
2 changes: 1 addition & 1 deletion zetta_utils/parsing/cue.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""cuelang parsing."""
import json
import os
import pathlib
import subprocess
Expand All @@ -8,6 +7,7 @@
import fsspec

from zetta_utils import log
from zetta_utils.parsing import json

logger = log.get_logger("zetta_utils")

Expand Down
50 changes: 50 additions & 0 deletions zetta_utils/parsing/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from __future__ import annotations

import json
from collections.abc import Iterator
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from _typeshed import SupportsRead, SupportsWrite

Check warning on line 8 in zetta_utils/parsing/json.py

View check run for this annotation

Codecov / codecov/patch

zetta_utils/parsing/json.py#L8

Added line #L8 was not covered by tests


def _mark_python_types(obj: Any) -> Any:
if isinstance(obj, tuple):
return {"__tuple__": [_mark_python_types(e) for e in obj]}
if isinstance(obj, list):
return [_mark_python_types(e) for e in obj]
if isinstance(obj, dict):
return {key: _mark_python_types(value) for key, value in obj.items()}
else:
return obj


class ZettaSpecJSONEncoder(json.JSONEncoder):
def encode(self, o: Any) -> str:
return super().encode(_mark_python_types(o))

def iterencode(self, o: Any, _one_shot: bool = False) -> Iterator[str]:
return super().iterencode(_mark_python_types(o), _one_shot=_one_shot)


def tuple_hook(obj):
if "__tuple__" in obj:
return tuple(obj["__tuple__"])
else:
return obj


def dumps(obj, **kwargs) -> str:
return json.dumps(obj, cls=ZettaSpecJSONEncoder, **kwargs)


def dump(obj: Any, fp: SupportsWrite[str], **kwargs) -> None:
json.dump(obj, fp, cls=ZettaSpecJSONEncoder, **kwargs)


def loads(s: str, **kwargs) -> Any:
return json.loads(s, object_hook=tuple_hook, **kwargs)


def load(fp: SupportsRead[str | bytes], **kwargs) -> Any:
return json.load(fp, object_hook=tuple_hook, **kwargs)
2 changes: 1 addition & 1 deletion zetta_utils/training/lightning/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import json
import os
from contextlib import ExitStack
from typing import Any, Dict, Final, List, Optional
Expand All @@ -16,6 +15,7 @@
from zetta_utils import builder, load_all_modules, log, mazepa, parsing
from zetta_utils.builder.build import BuilderPartial
from zetta_utils.cloud_management import execution_tracker, resource_allocation
from zetta_utils.parsing import json

logger = log.get_logger("zetta_utils")

Expand Down
3 changes: 1 addition & 2 deletions zetta_utils/training/lightning/trainers/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import datetime
import importlib.metadata
import json
import os
from typing import Any, Dict, List, Optional

Expand All @@ -15,6 +14,7 @@
from pytorch_lightning.strategies import ddp

from zetta_utils import builder, log
from zetta_utils.parsing import json

logger = log.get_logger("zetta_utils")
ONNX_OPSET_VERSION = 17
Expand Down Expand Up @@ -126,7 +126,6 @@ def log_config(config):
def save_checkpoint(
self, filepath, weights_only: bool = False, storage_options: Optional[Any] = None
): # pylint: disable=too-many-locals

if filepath.startswith("./"):
filepath = f"{self.default_root_dir}/{filepath[2:]}"
super().save_checkpoint(filepath, weights_only, storage_options)
Expand Down

0 comments on commit 42da668

Please sign in to comment.