Skip to content

Commit

Permalink
Add utility for random sampling dict with params
Browse files Browse the repository at this point in the history
  • Loading branch information
ziatdinovmax committed Aug 16, 2023
1 parent 484f70c commit 2cff802
Showing 1 changed file with 19 additions and 0 deletions.
19 changes: 19 additions & 0 deletions gpax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,25 @@ def split_dict(data: Dict[str, jnp.ndarray], chunk_size: int
return result


def random_sample_dict(data: Dict[str, jnp.ndarray],
num_samples: int) -> Dict[str, jnp.ndarray]:
"""Returns a dictionary with a smaller number of consistent random samples for each array.
Args:
data: Dictionary containing numpy arrays.
num_samples: Number of random samples required.
Returns:
Dictionary with the consistently sampled arrays.
"""

# Generate unique random indices
indices = onp.random.choice(
len(next(iter(data.values()))), size=num_samples, replace=False)

return {key: value[indices] for key, value in data.items()}


def get_haiku_dict(kernel_params: Dict[str, jnp.ndarray]) -> Dict[str, Dict[str, jnp.ndarray]]:
"""
Extracts weights and biases from viDKL dictionary into a separate
Expand Down

0 comments on commit 2cff802

Please sign in to comment.