From 2cff80219dff150bdfa976aae13a0e4d5ac6d35b Mon Sep 17 00:00:00 2001 From: Maxim Ziatdinov <34245227+ziatdinovmax@users.noreply.github.com> Date: Wed, 16 Aug 2023 16:46:33 -0400 Subject: [PATCH] Add utility for random sampling dict with params --- gpax/utils.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/gpax/utils.py b/gpax/utils.py index 0b4bea2..9a2c690 100644 --- a/gpax/utils.py +++ b/gpax/utils.py @@ -81,6 +81,25 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int return result +def random_sample_dict(data: Dict[str, jnp.ndarray], + num_samples: int) -> Dict[str, jnp.ndarray]: + """Returns a dictionary with a smaller number of consistent random samples for each array. + + Args: + data: Dictionary containing numpy arrays. + num_samples: Number of random samples required. + + Returns: + Dictionary with the consistently sampled arrays. + """ + + # Generate unique random indices + indices = onp.random.choice( + len(next(iter(data.values()))), size=num_samples, replace=False) + + return {key: value[indices] for key, value in data.items()} + + def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]: """ Extracts weights and biases from viDKL dictionary into a separate