From 77141468365f9302b7f3964fb7e25167b97bf406 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Wed, 5 Jun 2024 16:06:31 -0400 Subject: [PATCH] Use random number generator in weighted mode of select --- gflownet/utils/buffer.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/gflownet/utils/buffer.py b/gflownet/utils/buffer.py index c36339a13..041130c8e 100644 --- a/gflownet/utils/buffer.py +++ b/gflownet/utils/buffer.py @@ -292,26 +292,25 @@ def select( - weighted: data points are sampled with probability proportional to their score. - Args - ---- + Parameters + ---------- data_dict : dict A dictionary with samples (key "x") and scores (key "energy" or "rewards"). - n : int The number of samples to select from the dictionary. - mode : str Sampling mode. Options: permutation, weighted. - rng : np.random.Generator - A numpy random number generator, used for the permutation mode. Ignored - otherwise. + A numpy random number generator, used for the permutation and weighted + modes. If None (default), a generator with a random seed is used. Returns ------- list A batch of n samples, selected from data_dict. """ + if rng is None: + rng = np.random.default_rng() if n == 0: return [] samples = data_dict["x"] @@ -320,7 +319,6 @@ def select( if isinstance(samples, dict): samples = list(samples.values()) if mode == "permutation": - assert rng is not None indices = rng.choice( len(samples), size=n, @@ -339,7 +337,7 @@ def select( # need to keep its values only if isinstance(scores, dict): scores = np.fromiter(scores.values(), dtype=float) - indices = np.random.choice( + indices = rng.choice( len(samples), size=n, replace=False,