diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 42ce57b53..f0774410f 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -1538,7 +1538,7 @@ def sample_from_reward( format. """ samples_final = [] - max_reward = self.get_max_reward() + max_reward = self.proxy.get_max_reward() while len(samples_final) < n_samples: if proposal_distribution == "uniform": # TODO: sample only the remaining number of samples