From 13ccae511bbe4c65bdb253627324a06c59d4794e Mon Sep 17 00:00:00 2001
From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com>
Date: Sun, 13 Aug 2023 19:21:06 -0400
Subject: [PATCH] Add utility to split dict in batches

---
 gpax/utils.py       | 32 +++++++++++++++++++++++++++++++-
 tests/test_utils.py | 27 +++++++++++++++++++++++++--
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/gpax/utils.py b/gpax/utils.py
index 3ae286c..0b4bea2 100644
--- a/gpax/utils.py
+++ b/gpax/utils.py
@@ -7,7 +7,7 @@
 Created by Maxim Ziatdinov (email: maxim.ziatdinov@ai4microscopy.com)
 """
 
-from typing import Union, Dict, Type
+from typing import Union, Dict, Type, List
 
 import jax
 import jax.numpy as jnp
@@ -51,6 +51,36 @@ def split_in_batches(X_new: Union[onp.ndarray, jnp.ndarray],
     return X_split
 
 
+def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
+               ) -> List[Dict[str, jnp.ndarray]]:
+    """Splits a dictionary of arrays into a list of smaller dictionaries.
+
+    Args:
+        data: Dictionary containing jax numpy arrays.
+        chunk_size: Desired size of the smaller arrays.
+
+    Returns:
+        List of dictionaries with smaller jax numpy arrays.
+    """
+
+    # Get the length of the arrays (all values are assumed equally long)
+    N = len(next(iter(data.values())))
+
+    # Calculate the number of chunks; the last chunk may be shorter
+    num_chunks = int(onp.ceil(N / chunk_size))
+
+    # Slice every array in the dictionary with the same index window
+    result = []
+    for i in range(num_chunks):
+        start_idx = i * chunk_size
+        end_idx = min((i+1) * chunk_size, N)
+
+        chunk = {key: value[start_idx:end_idx] for key, value in data.items()}
+        result.append(chunk)
+
+    return result
+
+
 def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]:
     """
     Extracts weights and biases from viDKL dictionary into a separate
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 389b93f..4b95dde 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,11 +1,11 @@
 import sys
 
 import numpy as onp
 import jax.numpy as jnp
-from numpy.testing import assert_equal, assert_
+from numpy.testing import assert_equal, assert_, assert_array_equal
 
 sys.path.insert(0, "../gpax/")
 
-from gpax.utils import preprocess_sparse_image
+from gpax.utils import preprocess_sparse_image, split_dict
 
 
 def test_sparse_img_processing():
@@ -24,3 +24,26 @@ def test_sparse_img_processing():
     assert_equal(y.shape[0], X.shape[0])
     assert_equal(X_full.shape[0], 16*16)
     assert_equal(X_full.shape[1], 2)
+
+
+def test_split_dict():
+    data = {
+        'a': jnp.array([1, 2, 3, 4, 5, 6]),
+        'b': jnp.array([10, 20, 30, 40, 50, 60])
+    }
+    chunk_size = 4
+
+    result = split_dict(data, chunk_size)
+
+    expected = [
+        {'a': jnp.array([1, 2, 3, 4]), 'b': jnp.array([10, 20, 30, 40])},
+        {'a': jnp.array([5, 6]), 'b': jnp.array([50, 60])},
+    ]
+
+    # Check that the length of the result matches the expected length
+    assert len(result) == len(expected)
+
+    # Check that each chunk matches the expected chunk
+    for r, e in zip(result, expected):
+        for k in data:
+            assert_array_equal(r[k], e[k])
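
Usage sketch (not part of the patch): the snippet below illustrates how the new
split_dict utility behaves. The dictionary keys and array lengths are invented
for the example. With arrays of length 5 and chunk_size=2, num_chunks =
ceil(5 / 2) = 3, so the last chunk carries the single leftover element.

    import jax.numpy as jnp
    from gpax.utils import split_dict

    # Hypothetical dictionary of two equal-length jax arrays
    samples = {
        "means": jnp.arange(5.0),
        "scales": jnp.ones(5),
    }

    # chunk_size=2 over length-5 arrays -> chunks of lengths 2, 2, 1
    for batch in split_dict(samples, chunk_size=2):
        print({k: v.shape for k, v in batch.items()})
    # {'means': (2,), 'scales': (2,)}
    # {'means': (2,), 'scales': (2,)}
    # {'means': (1,), 'scales': (1,)}

Note that split_dict reads the length from the first value in the dictionary,
so all arrays are expected to share the same leading dimension.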