diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py index 9085d9f84..adc633c73 100644 --- a/gflownet/evaluator/base.py +++ b/gflownet/evaluator/base.py @@ -557,7 +557,7 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs): values are the figures. """ - fig_kde_pred = fig_kde_true = fig_reward_samples = None + fig_kde_pred = fig_kde_true = fig_reward_samples = fig_samples_topk = None if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None: (sample_space_batch, rewards_sample_space) = ( diff --git a/tests/gflownet/envs/test_tetris.py b/tests/gflownet/envs/test_tetris.py index 3bcb98528..998e6c727 100644 --- a/tests/gflownet/envs/test_tetris.py +++ b/tests/gflownet/envs/test_tetris.py @@ -547,6 +547,7 @@ def setup(self, env): self.env = env self.repeats = { "test__reset__state_is_source": 10, + "test__gflownet_minimal_runs": 0, } self.n_states = {} # TODO: Populate. @@ -559,6 +560,7 @@ def setup(self, env_full): self.env = env_full self.repeats = { "test__reset__state_is_source": 10, + "test__gflownet_minimal_runs": 0, } self.n_states = {} # TODO: Populate. diff --git a/tests/gflownet/evaluator/test_base.py b/tests/gflownet/evaluator/test_base.py index 85f5f81b5..f9616aa82 100644 --- a/tests/gflownet/evaluator/test_base.py +++ b/tests/gflownet/evaluator/test_base.py @@ -237,6 +237,9 @@ def test__eval(gflownet_for_tests, parameterization): elif parameterization == "ctorus": for figname, fig in figs.items(): assert isinstance(figname, str) + # plot_samples_topk not implemented in ctorus + if figname == "Samples TopK": + continue assert isinstance(fig, plt.Figure) else: raise ValueError(f"Unknown parameterization: {parameterization}")