Skip to content

Commit

Permalink
Add utility to split dict in batches
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 13, 2023
1 parent 728b59b commit 13ccae5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 3 deletions.
32 changes: 31 additions & 1 deletion gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
Created by Maxim Ziatdinov (email: [email protected])
"""

from typing import Union, Dict, Type
from typing import Union, Dict, Type, List

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -51,6 +51,36 @@ def split_in_batches(X_new: Union[onp.ndarray, jnp.ndarray],
return X_split


def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
               ) -> List[Dict[str, jnp.ndarray]]:
    """Splits a dictionary of arrays into a list of smaller dictionaries.

    All arrays are assumed to share the same leading dimension; each output
    dictionary holds the same keys with slices of at most ``chunk_size``
    elements along that dimension.

    Args:
        data: Dictionary containing arrays with a common leading dimension.
        chunk_size: Desired size of the smaller arrays. The last chunk may
            be shorter when the length is not evenly divisible.

    Returns:
        List of dictionaries with smaller arrays. Empty list if ``data``
        is empty.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")
    # An empty dict has nothing to split; avoid StopIteration from next()
    if not data:
        return []

    # Length of the shared leading dimension, taken from the first array
    N = len(next(iter(data.values())))

    # Number of chunks, rounding up so a partial final chunk is included
    num_chunks = int(onp.ceil(N / chunk_size))

    # Slice every array with the same [start:end) window for each chunk
    result = []
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, N)
        chunk = {key: value[start_idx:end_idx] for key, value in data.items()}
        result.append(chunk)

    return result


def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]:
"""
Extracts weights and biases from viDKL dictionary into a separate
Expand Down
27 changes: 25 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import sys
import numpy as onp
import jax.numpy as jnp
from numpy.testing import assert_equal, assert_
from numpy.testing import assert_equal, assert_, assert_array_equal

sys.path.insert(0, "../gpax/")

from gpax.utils import preprocess_sparse_image
from gpax.utils import preprocess_sparse_image, split_dict


def test_sparse_img_processing():
Expand All @@ -24,3 +24,26 @@ def test_sparse_img_processing():
assert_equal(y.shape[0], X.shape[0])
assert_equal(X_full.shape[0], 16*16)
assert_equal(X_full.shape[1], 2)


def test_split_dict():
    # Two parallel length-6 arrays split with chunk_size=4 should yield
    # one full chunk of 4 followed by a remainder chunk of 2.
    chunk_size = 4
    data = {
        'a': jnp.array([1, 2, 3, 4, 5, 6]),
        'b': jnp.array([10, 20, 30, 40, 50, 60])
    }
    expected = [
        {'a': jnp.array([1, 2, 3, 4]), 'b': jnp.array([10, 20, 30, 40])},
        {'a': jnp.array([5, 6]), 'b': jnp.array([50, 60])},
    ]

    result = split_dict(data, chunk_size)

    # The number of chunks must match
    assert len(result) == len(expected)

    # Every key of every chunk must match element-wise
    for got, want in zip(result, expected):
        for key in data:
            assert_array_equal(got[key], want[key])

0 comments on commit 13ccae5

Please sign in to comment.