From f684b6f240d994ea266d0a82efa5c5a88074c14a Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Thu, 26 Sep 2024 18:16:44 -0400 Subject: [PATCH] little things --- bnpm/file_helpers.py | 142 +++++++++++++++++++++++++++++++++++++++++++ bnpm/testing.py | 29 +++++++-- 2 files changed, 166 insertions(+), 5 deletions(-) diff --git a/bnpm/file_helpers.py b/bnpm/file_helpers.py index cab2391..b0916a9 100644 --- a/bnpm/file_helpers.py +++ b/bnpm/file_helpers.py @@ -1,3 +1,4 @@ +from typing import Any, Union, Callable import pickle import json import yaml @@ -414,6 +415,147 @@ def walk(d, fn): scipy.io.savemat(filepath, data_cleaned, **kwargs_scipy_savemat) +# def zarr_save( +# obj, +# filepath, +# mode='w', +# mkdir=False, +# allow_overwrite=False, +# function_unsupported: Callable = lambda data: repr(data), +# **kwargs_zarr, +# ): +# """ +# Saves an object to a zarr file. Uses recursive approach to save +# hierarchical objects. +# RH 2024 + +# Args: +# obj (object): +# Object to save. Can be any array or hierarchical object. +# filepath (str): +# Path to save object to. +# mode (str): +# Mode passed to zarr.open. +# Can be: +# 'w' (create; overwrites existing data) +# 'a' (read/write; creates the store if it does not exist) +# 'w-' (create; fails if the store already exists) +# mkdir (bool): +# If True, creates parent directory if it does not exist. +# allow_overwrite (bool): +# If True, allows overwriting of existing file. +# kwargs_zarr (dict): +# Keyword arguments to pass to zarr.open. +# """ +# import zarr + +# path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite) + +# import numpy as np +# import scipy.sparse + +# def save_data_to_zarr(data, group, name=None): +# """ +# Recursively saves complex nested data structures into a Zarr group. + +# Parameters: +# - data: The data to save (dict, list, tuple, np.ndarray, int, float, str, bool, or None). +# - group: The Zarr group to save data into. +# - name: The name of the dataset or subgroup (used in recursive calls). 
+# """ +# if isinstance(data, dict): +# # Use the given name or the current group +# sub_group = group.require_group(name) if name else group +# for key, value in data.items(): +# # Ensure keys are strings +# key_str = str(key) if not isinstance(key, str) else key +# # Recursively save data +# save_data_to_zarr(value, sub_group, name=key_str) +# elif isinstance(data, (list, tuple)): +# # Create a subgroup for lists and tuples +# sub_group = group.require_group(name) if name else group +# for idx, item in enumerate(data): +# key_str = str(idx) +# save_data_to_zarr(item, sub_group, name=key_str) +# elif isinstance(data, np.ndarray): +# if name is None: +# raise ValueError("Name must be provided for dataset") +# group.create_dataset(name, data=data) +# elif isinstance(data, scipy.sparse.spmatrix): +# if name is None: +# raise ValueError("Name must be provided for dataset") +# group.create_dataset(name, data=data) +# elif isinstance(data, (int, float, str, bool)): +# if name is None: +# raise ValueError("Name must be provided for dataset") +# group.create_dataset(name, data=data) +# elif data is None: +# if name is None: +# raise ValueError("Name must be provided for dataset") +# # Store None as a special attribute +# group.attrs[name] = 'None' +# else: +# # For unsupported types, store the string representation +# if name is None: +# raise ValueError("Name must be provided for attribute") +# group.attrs[name] = repr(data) + +# zarr_group = zarr.open(path, mode=mode, **kwargs_zarr) +# save_data_to_zarr(obj, zarr_group, name=None) + + +# def zarr_load( +# filepath, +# mode='r', +# **kwargs_zarr, +# ): +# """ +# Loads a zarr file. Uses recursive approach to load hierarchical +# objects. +# RH 2024 + +# Args: +# filepath (str): +# Path to zarr file. +# mode (str): +# Mode to open file in. +# kwargs_zarr (dict): +# Keyword arguments to pass to zarr.open. 
+# """ +# import zarr + +# # path = prepare_filepath_for_loading(filepath, must_exist=True) +# path = filepath + +# def load_data_from_zarr(group): +# """ +# Recursively loads complex nested data structures from a Zarr group. + +# Parameters: +# - group: The Zarr group to load data from. + +# Returns: +# - The loaded data (dict, list, tuple, np.ndarray, int, float, str, bool, or None). +# """ +# data = {} +# for key in group.array_keys(): +# data[key] = group[key][...] +# for key in group.group_keys(): +# data[key] = load_data_from_zarr(group[key]) +# for key, value in group.attrs.items(): +# if value == 'None': +# data[key] = None +# else: +# try: +# data[key] = eval(value) +# except Exception: +# data[key] = value +# return data + +# zarr_group = zarr.open(path, mode=mode, **kwargs_zarr) +# return load_data_from_zarr(zarr_group) + + def hash_file(path, type_hash='MD5', buffer_size=65536): """ Gets hash of a file. diff --git a/bnpm/testing.py b/bnpm/testing.py index e44c23e..3a371b1 100644 --- a/bnpm/testing.py +++ b/bnpm/testing.py @@ -90,10 +90,15 @@ def _checker( at = np.abs(true) r_diff = diff / at if np.all(at != 0) else np.inf r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff)) - reason = f"Equivalence: Relative difference: mean={r_diff_mean}, max={r_diff_max}, any_nan={any_nan}" + ## fraction of mismatches + n_elements = np.prod(test.shape) + n_mismatches = np.sum(diff > 0) + frac_mismatches = n_mismatches / n_elements + ## Use scientific notation and round to 3 decimal places + reason = f"Equivalence: Relative difference: mean={r_diff_mean:.3e}, max={r_diff_max:.3e}, any_nan={any_nan}, n_elements={n_elements}, n_mismatches={n_mismatches}, frac_mismatches={frac_mismatches:.3e}" else: reason = f"Values are not numpy numeric types. 
types: {test.dtype}, {true.dtype}" - elif out == True: + else: reason = "equivlance" return out, reason @@ -158,9 +163,12 @@ def __call__( if len(true) != len(test): result = (False, 'length_mismatch') else: - result = {} - for idx, (i, j) in enumerate(zip(test, true)): - result[str(idx)] = self.__call__(i, j, path=path + [str(idx)]) + if all([isinstance(i, (int, float, complex, np.number)) for i in true]): + result = self._checker(np.array(test), np.array(true), path) + else: + result = {} + for idx, (i, j) in enumerate(zip(test, true)): + result[str(idx)] = self.__call__(i, j, path=path + [str(idx)]) ## STRING elif isinstance(true, str): result = (test == true, 'equivalence') @@ -170,6 +178,17 @@ def __call__( ## NONE elif true is None: result = (test is None, 'equivalence') + + ## OBJECT with __dict__ + elif hasattr(true, '__dict__'): + result = {} + for key in true.__dict__: + if key.startswith('_'): + continue + if not hasattr(test, key): + result[str(key)] = (False, 'key not found') + else: + result[str(key)] = self.__call__(getattr(test, key), getattr(true, key), path=path + [str(key)]) ## N/A else: result = (None, 'not tested')