Skip to content

Commit

Permalink
little things
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Sep 26, 2024
1 parent 8aec3e2 commit f684b6f
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 5 deletions.
142 changes: 142 additions & 0 deletions bnpm/file_helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Union, Callable
import pickle
import json
import yaml
Expand Down Expand Up @@ -414,6 +415,147 @@ def walk(d, fn):
scipy.io.savemat(filepath, data_cleaned, **kwargs_scipy_savemat)


# def zarr_save(
# obj,
# filepath,
# mode='w',
# mkdir=False,
# allow_overwrite=False,
# function_unsupported: Callable = lambda data: repr(data),
# **kwargs_zarr,
# ):
# """
# Saves an object to a zarr file. Uses recursive approach to save
# hierarchical objects.
# RH 2024

# Args:
# obj (object):
# Object to save. Can be any array or hierarchical object.
# filepath (str):
# Path to save object to.
# mode (str):
# Mode to open file in.
# Can be:
# 'wb' (write binary)
# 'ab' (append binary)
# 'xb' (exclusive write binary. Raises FileExistsError if file already exists.)
# mkdir (bool):
# If True, creates parent directory if it does not exist.
# allow_overwrite (bool):
# If True, allows overwriting of existing file.
# kwargs_zarr (dict):
# Keyword arguments to pass to zarr.save.
# """
# import zarr

# path = prepare_filepath_for_saving(filepath, mkdir=mkdir, allow_overwrite=allow_overwrite)

# import numpy as np
# import scipy.sparse

# def save_data_to_zarr(data, group, name=None):
# """
# Recursively saves complex nested data structures into a Zarr group.

# Parameters:
# - data: The data to save (dict, list, tuple, np.ndarray, int, float, str, bool, or None).
# - group: The Zarr group to save data into.
# - name: The name of the dataset or subgroup (used in recursive calls).
# """
# if isinstance(data, dict):
# # Use the given name or the current group
# sub_group = group.require_group(name) if name else group
# for key, value in data.items():
# # Ensure keys are strings
# key_str = str(key) if not isinstance(key, str) else key
# # Recursively save data
# save_data_to_zarr(value, sub_group, name=key_str)
# elif isinstance(data, (list, tuple)):
# # Create a subgroup for lists and tuples
# sub_group = group.require_group(name) if name else group
# for idx, item in enumerate(data):
# key_str = str(idx)
# save_data_to_zarr(item, sub_group, name=key_str)
# elif isinstance(data, np.ndarray):
# if name is None:
# raise ValueError("Name must be provided for dataset")
# group.create_dataset(name, data=data)
# elif isinstance(data, scipy.sparse.spmatrix):
# if name is None:
# raise ValueError("Name must be provided for dataset")
# group.create_dataset(name, data=data)
# elif isinstance(data, (int, float, str, bool)):
# if name is None:
# raise ValueError("Name must be provided for dataset")
# group.create_dataset(name, data=data)
# elif data is None:
# if name is None:
# raise ValueError("Name must be provided for dataset")
# # Store None as a special attribute
# group.attrs[name] = 'None'
# else:
# # For unsupported types, store the string representation
# if name is None:
# raise ValueError("Name must be provided for attribute")
# group.attrs[name] = repr(data)

# zarr_group = zarr.open(path, mode=mode, **kwargs_zarr)
# save_data_to_zarr(obj, zarr_group, name=None)


# def zarr_load(
# filepath,
# mode='r',
# **kwargs_zarr,
# ):
# """
# Loads a zarr file. Uses recursive approach to load hierarchical
# objects.
# RH 2024

# Args:
# filepath (str):
# Path to zarr file.
# mode (str):
# Mode to open file in.
# kwargs_zarr (dict):
# Keyword arguments to pass to zarr.load.
# """
# import zarr

# # path = prepare_filepath_for_loading(filepath, must_exist=True)
# path = filepath

# def load_data_from_zarr(group):
# """
# Recursively loads complex nested data structures from a Zarr group.

# Parameters:
# - group: The Zarr group to load data from.

# Returns:
# - The loaded data (dict, list, tuple, np.ndarray, int, float, str, bool, or None).
# """
# data = {}
# for key in group.array_keys():
# data[key] = group[key][...]
# for key in group.group_keys():
# data[key] = load_data_from_zarr(group[key])
# for key, value in group.attrs.items():
# if value == 'None':
# data[key] = None
# else:
# try:
# data[key] = eval(value)
# except Exception:
# data[key] = value
# return data

# zarr_group = zarr.open(path, mode=mode, **kwargs_zarr)
# return load_data_from_zarr(zarr_group)


def hash_file(path, type_hash='MD5', buffer_size=65536):
"""
Gets hash of a file.
Expand Down
29 changes: 24 additions & 5 deletions bnpm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,15 @@ def _checker(
at = np.abs(true)
r_diff = diff / at if np.all(at != 0) else np.inf
r_diff_mean, r_diff_max, any_nan = np.nanmean(r_diff), np.nanmax(r_diff), np.any(np.isnan(r_diff))
reason = f"Equivalence: Relative difference: mean={r_diff_mean}, max={r_diff_max}, any_nan={any_nan}"
## fraction of mismatches
n_elements = np.prod(test.shape)
n_mismatches = np.sum(diff > 0)
frac_mismatches = n_mismatches / n_elements
## Use scientific notation and round to 3 decimal places
reason = f"Equivalence: Relative difference: mean={r_diff_mean:.3e}, max={r_diff_max:.3e}, any_nan={any_nan}, n_elements={n_elements}, n_mismatches={n_mismatches}, frac_mismatches={frac_mismatches:.3e}"
else:
reason = f"Values are not numpy numeric types. types: {test.dtype}, {true.dtype}"
elif out == True:
else:
reason = "equivlance"

return out, reason
Expand Down Expand Up @@ -158,9 +163,12 @@ def __call__(
if len(true) != len(test):
result = (False, 'length_mismatch')
else:
result = {}
for idx, (i, j) in enumerate(zip(test, true)):
result[str(idx)] = self.__call__(i, j, path=path + [str(idx)])
if all([isinstance(i, (int, float, complex, np.number)) for i in true]):
result = self._checker(np.array(test), np.array(true), path)
else:
result = {}
for idx, (i, j) in enumerate(zip(test, true)):
result[str(idx)] = self.__call__(i, j, path=path + [str(idx)])
## STRING
elif isinstance(true, str):
result = (test == true, 'equivalence')
Expand All @@ -170,6 +178,17 @@ def __call__(
## NONE
elif true is None:
result = (test is None, 'equivalence')

## OBJECT with __dict__
elif hasattr(true, '__dict__'):
result = {}
for key in true.__dict__:
if key.startswith('_'):
continue
if not hasattr(test, key):
result[str(key)] = (False, 'key not found')
else:
result[str(key)] = self.__call__(getattr(test, key), getattr(true, key), path=path + [str(key)])
## N/A
else:
result = (None, 'not tested')
Expand Down

0 comments on commit f684b6f

Please sign in to comment.