diff --git a/gflownet/envs/tetris.py b/gflownet/envs/tetris.py
index 9cd2f722b..bf927de2b 100644
--- a/gflownet/envs/tetris.py
+++ b/gflownet/envs/tetris.py
@@ -534,6 +534,7 @@ def plot_samples_topk(
         k_top: int = 10,
         n_rows: int = 2,
         dpi: int = 150,
+        **kwargs,
     ):
         """
         Plot tetris boards of top K samples.
@@ -543,7 +544,7 @@ def plot_samples_topk(
         samples : list
             List of terminating states sampled from the policy.
         rewards : list
-            List of terminating states.
+            Rewards of the samples.
         k_top : int
             The number of samples that will be included in the plot. The k_top
             samples with the highest reward are selected.
diff --git a/gflownet/evaluator/base.py b/gflownet/evaluator/base.py
index b2d02eca7..9085d9f84 100644
--- a/gflownet/evaluator/base.py
+++ b/gflownet/evaluator/base.py
@@ -98,6 +98,9 @@ def define_new_metrics(self):
             },
         }
 
+    # TODO: this method will most likely crash if used (top_k_period != -1) because
+    # self.gfn.env.top_k_metrics_and_plots still makes use of env.proxy.
+    # Re-implementing this will require a non-trivial amount of work.
     @torch.no_grad()
     def eval_top_k(self, it, gfn_states=None, random_states=None):
         """
@@ -124,6 +127,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None):
         do_random = it // self.logger.test.top_k_period == 1
         duration = None
         summary = {}
+        # TODO: Why deepcopy?
         prob = copy.deepcopy(self.random_action_prob)
         print()
         if not gfn_states:
@@ -517,9 +521,12 @@ def plot(
         Plots this evaluator should do, returned as a dict `{str: plt.Figure}`
         which will be logged.
 
-        By default, this method will call the `plot_reward_samples` method of the
-        GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's
-        environment if it exists for both the `kde_pred` and `kde_true` arguments.
+        By default, this method will call the following methods of the GFlowNetAgent's
+        environment if they exist:
+
+        - `plot_reward_samples`
+        - `plot_kde` (for both the `kde_pred` and `kde_true` arguments)
+        - `plot_samples_topk`
 
         Extend this method to add more plots:
 
@@ -574,8 +581,26 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
             sample_space_batch, kde_true, **plot_kwargs
         )
 
+        # Default to None so the return dict below does not raise NameError
+        fig_samples_topk = None
+        # TODO: consider moving this to eval_top_k once fixed
+        if hasattr(self.gfn.env, "plot_samples_topk"):
+            if x_sampled is None:
+                batch, _ = self.gfn.sample_batch(
+                    n_forward=self.config.n_top_k, train=False
+                )
+                x_sampled = batch.get_terminating_states()
+            rewards = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_sampled))
+            fig_samples_topk = self.gfn.env.plot_samples_topk(
+                x_sampled,
+                rewards,
+                self.config.top_k,
+                **plot_kwargs,
+            )
+
         return {
             "True reward and GFlowNet samples": fig_reward_samples,
             "GFlowNet KDE Policy": fig_kde_pred,
             "Reward KDE": fig_kde_true,
+            "Samples TopK": fig_samples_topk,
         }