From 6b684b629209970e2f2bb80e564616eae42e0b49 Mon Sep 17 00:00:00 2001 From: Alex Hadley Date: Fri, 3 Nov 2023 15:45:28 -0700 Subject: [PATCH] #37 Update Logger.convert_to_json() to support Numpy values and Pandas DataFrames --- datalogger/_logger.py | 60 +++++++++++++++++++++++++++++++++---------- 1 file changed, 47 insertions(+), 13 deletions(-) diff --git a/datalogger/_logger.py b/datalogger/_logger.py index 9eadb05..f018df2 100644 --- a/datalogger/_logger.py +++ b/datalogger/_logger.py @@ -3,10 +3,12 @@ from __future__ import annotations from typing import TypeVar, Generic, Any, overload, get_type_hints, get_origin from collections.abc import Callable, Sequence, Collection, Mapping +from abc import ABC, abstractmethod import os import sys -from abc import ABC, abstractmethod from datetime import datetime, timezone +import numpy as np +import pandas as pd # type: ignore from datalogger._variables import Coord, DataVar from datalogger._logs import LogMetadata, DataLog, DictLog from datalogger._get_filename import get_filename @@ -216,39 +218,68 @@ def make_log(log_metadata: LogMetadata) -> DataLog: return self._log(make_log, description, commit_id) @classmethod - def _convert_to_json(cls, obj: Any) -> Any: + def convert_to_json( + cls, obj: Any, convert: Callable[[Any], Any] | None = None + ) -> Any: """ - Return a JSON-serializable version of the given object by converting ``Mapping`` - and ``Collection`` objects to dictionaries and lists, converting other - non-JSON-serializable values to ``repr`` strings, and converting all dictionary - keys to strings. + Return a JSON-serializable version of the given object. This function is used to + convert objects to JSON for :py:meth:`Logger.log_dict` and + :py:meth:`Logger.log_props`. + + 1. If provided, ``convert()`` will be used to convert the object. + + 2. Numpy scalars will be unpacked and Pandas DataFrames will be converted to + dictionaries. + + 3. ``Mapping`` and ``Collection`` objects will be converted to dictionaries and + lists, with keys converted to strings and values converted according to these + rules. + + 4. Other non-JSON-serializable values will be converted to ``repr()`` strings. """ + if convert is not None: + obj = convert(obj) + if isinstance(obj, (np.generic, np.ndarray)) and obj.ndim == 0: + obj = obj.item() # Unpack NumPy scalars to simple Python values + if isinstance(obj, pd.DataFrame): + obj = obj.to_dict() # Convert DataFrames to dictionaries if isinstance(obj, (str, int, float, bool)) or obj is None: return obj if isinstance(obj, Mapping): - return {str(k): cls._convert_to_json(v) for k, v in obj.items()} + return {str(k): cls.convert_to_json(v, convert) for k, v in obj.items()} if isinstance(obj, Collection): - return [cls._convert_to_json(v) for v in obj] + return [cls.convert_to_json(v, convert) for v in obj] return repr(obj) def log_dict( - self, description: str, dict_data: dict[str, Any], commit_id: int | None = None + self, + description: str, + dict_data: dict[str, Any], + commit_id: int | None = None, + convert: Callable[[Any], Any] | None = None, ) -> DictLog: """ Save the given dictionary data and corresponding metadata in a JSON file, and return a :py:class:`DictLog` with this data and metadata. + Objects will be converted according to :py:meth:`Logger.convert_to_json`, with + ``convert()`` passed to that function. + The log will be tagged with the given commit ID, or the latest commit ID if none is given (and if this Logger has a corresponding ParamDB). """ def make_log(log_metadata: LogMetadata) -> DictLog: - return DictLog(log_metadata, self._convert_to_json(dict_data)) + return DictLog(log_metadata, self.convert_to_json(dict_data, convert)) return self._log(make_log, description, commit_id) def log_props( - self, description: str, obj: Any, commit_id: int | None = None + self, + description: str, + obj: Any, + commit_id: int | None = None, + convert: Callable[[Any], Any] | None = None, ) -> DictLog: """ Save a dictionary of the given object's properties and corresponding metadata in @@ -261,11 +292,14 @@ class Example: value: LoggedProp number: LoggedProp[float] + Objects will be converted according to :py:meth:`Logger.convert_to_json`, with + ``convert()`` passed to that function. + The log will be tagged with the given commit ID, or the latest commit ID if none is given (and if this Logger has a corresponding ParamDB). """ - logged_props: dict[str, Any] = {} obj_class = type(obj) + logged_props: dict[str, Any] = {} try: type_hints = get_type_hints(obj_class) except Exception as exc: @@ -278,4 +312,4 @@ class Example: if type_hint is LoggedProp or get_origin(type_hint) is LoggedProp: if hasattr(obj, name): logged_props[name] = getattr(obj, name) - return self.log_dict(description, logged_props, commit_id) + return self.log_dict(description, logged_props, commit_id, convert)