Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Move DEFAULT_DATA_PROCESSORS #216

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion qualang_tools/results/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from qualang_tools.results.results import fetching_tool
from qualang_tools.results.results import progress_counter
from qualang_tools.results.results import wait_until_job_is_paused

from qualang_tools.results.data_handler import DataHandler, data_processors

__all__ = ["fetching_tool", "progress_counter", "wait_until_job_is_paused", "DataHandler", "data_processors"]
39 changes: 35 additions & 4 deletions qualang_tools/results/data_handler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,37 @@
from .data_folder_tools import *
from . import data_processors
from .data_processors import DEFAULT_DATA_PROCESSORS
from .data_handler import *

__all__ = [*data_folder_tools.__all__, data_processors, DEFAULT_DATA_PROCESSORS, *data_handler.__all__]
DEFAULT_DATA_PROCESSORS = [
data_processors.MatplotlibPlotSaver,
data_processors.NumpyArraySaver,
]

try:
import xarray # noqa: F401

DEFAULT_DATA_PROCESSORS.append(data_processors.XarraySaver)
except ImportError:
pass

from .data_folder_tools import ( # noqa: E402
DEFAULT_FOLDER_PATTERN,
extract_data_folder_properties,
get_latest_data_folder,
create_data_folder,
)
from .data_handler import save_data, DataHandler # noqa: E402
from .data_processors import DataProcessor, MatplotlibPlotSaver, NumpyArraySaver, XarraySaver # noqa: E402


__all__ = [
"DEFAULT_FOLDER_PATTERN",
"extract_data_folder_properties",
"get_latest_data_folder",
"create_data_folder",
"DataProcessor",
"MatplotlibPlotSaver",
"NumpyArraySaver",
"XarraySaver",
"DEFAULT_DATA_PROCESSORS",
"save_data",
"DataHandler",
]
3 changes: 0 additions & 3 deletions qualang_tools/results/data_handler/data_folder_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
from datetime import datetime


__all__ = ["DEFAULT_FOLDER_PATTERN", "extract_data_folder_properties", "get_latest_data_folder", "create_data_folder"]


DEFAULT_FOLDER_PATTERN = "%Y-%m-%d/#{idx}_{name}_%H%M%S"


Expand Down
8 changes: 4 additions & 4 deletions qualang_tools/results/data_handler/data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Dict, Optional, Sequence, Union
import warnings

from .data_processors import DEFAULT_DATA_PROCESSORS, DataProcessor
from .data_processors import DataProcessor
from .data_folder_tools import (
DEFAULT_FOLDER_PATTERN,
create_data_folder,
Expand All @@ -14,8 +14,6 @@
)


__all__ = ["save_data", "DataHandler"]

NODE_FILENAME = "node.json"


Expand Down Expand Up @@ -106,7 +104,9 @@ class DataHandler:
data_handler.save_data(data)
"""

default_data_processors = DEFAULT_DATA_PROCESSORS
from . import DEFAULT_DATA_PROCESSORS as _DEFAULT_DATA_PROCESSORS

default_data_processors = _DEFAULT_DATA_PROCESSORS
root_data_folder: Path = None
folder_pattern: str = DEFAULT_FOLDER_PATTERN
data_filename: str = "data.json"
Expand Down
19 changes: 0 additions & 19 deletions qualang_tools/results/data_handler/data_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@
from matplotlib import pyplot as plt
import numpy as np

__all__ = ["DEFAULT_DATA_PROCESSORS", "DataProcessor", "MatplotlibPlotSaver", "NumpyArraySaver", "XarraySaver"]


DEFAULT_DATA_PROCESSORS = []


def iterate_nested_dict(
d: Dict[str, Any], parent_keys: Optional[List[str]] = None
Expand Down Expand Up @@ -78,9 +73,6 @@ def post_process(self, data_folder: Path):
fig.savefig(data_folder / path)


DEFAULT_DATA_PROCESSORS.append(MatplotlibPlotSaver)


class NumpyArraySaver(DataProcessor):
merge_arrays: bool = True
merged_array_name: str = "arrays.npz"
Expand Down Expand Up @@ -118,9 +110,6 @@ def post_process(self, data_folder: Path):
np.save(data_folder / path.with_suffix(".npy"), arr)


DEFAULT_DATA_PROCESSORS.append(NumpyArraySaver)


class XarraySaver(DataProcessor):
merge_arrays: bool = False
merged_array_name: str = "xarrays"
Expand Down Expand Up @@ -184,11 +173,3 @@ def post_process(self, data_folder: Path):
else:
for path, array in self.data_arrays.items():
array.to_netcdf(data_folder / path.with_suffix(self.file_suffix))


try:
import xarray # noqa: F401

DEFAULT_DATA_PROCESSORS.append(XarraySaver)
except ImportError:
pass
2 changes: 1 addition & 1 deletion tests/data_handler/test_numpy_array_saver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np

from qualang_tools.results.data_handler.data_processors import DEFAULT_DATA_PROCESSORS, NumpyArraySaver
from qualang_tools.results.data_handler import NumpyArraySaver


def test_numpy_array_saver_process_merged():
Expand Down
Loading