Skip to content

Commit

Permalink
Use random number generator in weighted mode of select
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jun 5, 2024
1 parent 72f961b commit 7714146
Showing 1 changed file with 7 additions and 9 deletions.
16 changes: 7 additions & 9 deletions gflownet/utils/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,26 +292,25 @@ def select(
- weighted: data points are sampled with probability proportional to their
score.
Args
----
Parameters
----------
data_dict : dict
A dictionary with samples (key "x") and scores (key "energy" or "rewards").
n : int
The number of samples to select from the dictionary.
mode : str
Sampling mode. Options: permutation, weighted.
rng : np.random.Generator
A numpy random number generator, used for the permutation mode. Ignored
otherwise.
A numpy random number generator, used for the permutation and weighted
modes. If None (default), a generator with a random seed is used.
Returns
-------
list
A batch of n samples, selected from data_dict.
"""
if rng is None:
rng = np.random.default_rng()
if n == 0:
return []
samples = data_dict["x"]
Expand All @@ -320,7 +319,6 @@ def select(
if isinstance(samples, dict):
samples = list(samples.values())
if mode == "permutation":
assert rng is not None
indices = rng.choice(
len(samples),
size=n,
Expand All @@ -339,7 +337,7 @@ def select(
# need to keep its values only
if isinstance(scores, dict):
scores = np.fromiter(scores.values(), dtype=float)
indices = np.random.choice(
indices = rng.choice(
len(samples),
size=n,
replace=False,
Expand Down

0 comments on commit 7714146

Please sign in to comment.