Skip to content

Commit

Permalink
Evaluator: add samples_topk to plot(); add TODOs
Browse files Browse the repository at this point in the history
  • Loading branch information
alexhernandezgarcia committed Jun 5, 2024
1 parent 4f4fcad commit 25a45e2
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 4 deletions.
3 changes: 2 additions & 1 deletion gflownet/envs/tetris.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def plot_samples_topk(
k_top: int = 10,
n_rows: int = 2,
dpi: int = 150,
**kwargs,
):
"""
Plot tetris boards of top K samples.
Expand All @@ -543,7 +544,7 @@ def plot_samples_topk(
samples : list
List of terminating states sampled from the policy.
rewards : list
List of terminating states.
Rewards of the samples.
k_top : int
The number of samples that will be included in the plot. The k_top samples
with the highest reward are selected.
Expand Down
29 changes: 26 additions & 3 deletions gflownet/evaluator/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ def define_new_metrics(self):
},
}

# TODO: this method will most likely crash if used (top_k_period != -1) because
# self.gfn.env.top_k_metrics_and_plots still makes use of env.proxy.
# Re-implementing this wil require a non-trivial amount of work.
@torch.no_grad()
def eval_top_k(self, it, gfn_states=None, random_states=None):
"""
Expand All @@ -124,6 +127,7 @@ def eval_top_k(self, it, gfn_states=None, random_states=None):
do_random = it // self.logger.test.top_k_period == 1
duration = None
summary = {}
# TODO: Why deepcopy?
prob = copy.deepcopy(self.random_action_prob)
print()
if not gfn_states:
Expand Down Expand Up @@ -517,9 +521,12 @@ def plot(
Plots this evaluator should do, returned as a dict `{str: plt.Figure}` which
will be logged.
By default, this method will call the `plot_reward_samples` method of the
GFlowNetAgent's environment, and the `plot_kde` method of the GFlowNetAgent's
environment if it exists for both the `kde_pred` and `kde_true` arguments.
By default, this method will call the following methods of the GFlowNetAgent's
environment if they exist:
- `plot_reward_samples`
- `plot_kde` (for both the `kde_pred` and `kde_true` arguments)
- `plot_samples_topk`
Extend this method to add more plots:
Expand Down Expand Up @@ -574,8 +581,24 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
sample_space_batch, kde_true, **plot_kwargs
)

# TODO: consider moving this to eval_top_k once fixed
if hasattr(self.gfn.env, "plot_samples_topk"):
if x_sampled is None:
batch, _ = self.gfn.sample_batch(
n_forward=self.config.n_top_k, train=False
)
x_sampled = batch.get_terminating_states()
rewards = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_sampled))
fig_samples_topk = self.gfn.env.plot_samples_topk(
x_sampled,
rewards,
self.config.top_k,
**plot_kwargs,
)

return {
"True reward and GFlowNet samples": fig_reward_samples,
"GFlowNet KDE Policy": fig_kde_pred,
"Reward KDE": fig_kde_true,
"Samples TopK": fig_samples_topk,
}

0 comments on commit 25a45e2

Please sign in to comment.