Re-integrate changes from main lost in merge
carriepl-mila committed Jun 4, 2024
1 parent 800314c commit a9bea9a
Showing 5 changed files with 69 additions and 20 deletions.
2 changes: 2 additions & 0 deletions config/evaluator/base.yaml
@@ -16,6 +16,8 @@ logprobs_batch_size: 100
logprobs_bootstrap_size: 10000
# Maximum number of test data points to compute log likelihood probs.
max_data_logprobs: 1e5
# Number of points to obtain a grid to estimate the reward density
n_grid: 40000
train_log_period: 1
checkpoints_period: 1000
# List of metrics as per gflownet/eval/evaluator.py:METRICS_NAMES
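Note: the new n_grid option feeds the grid-based density estimation added further down in this commit, where GFlowNetAgent.get_sample_space_and_reward() passes it to the environment's get_grid_terminating_states(). A minimal Python sketch of that flow, using only names that appear in this diff (the standalone env, proxy and n_grid variables are illustrative):

# Sketch only: mirrors the logic added to gflownet.py below; assumes a grid-capable env.
sample_space = env.get_grid_terminating_states(n_grid)  # n_grid = 40000 from this config
rewards_sample_space = proxy.rewards(env.states2proxy(sample_space))
reward_density = rewards_sample_space / rewards_sample_space.sum()  # density estimate over the grid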
45 changes: 31 additions & 14 deletions gflownet/evaluator/base.py
@@ -129,7 +129,10 @@ def eval_top_k(self, it, gfn_states=None, random_states=None):
if not gfn_states:
# sample states from the current gfn
batch = Batch(
env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float
env=self.gfn.env,
proxy=self.gfn.proxy,
device=self.gfn.device,
float_type=self.gfn.float,
)
self.gfn.random_action_prob = 0
t = time.time()
@@ -154,7 +157,10 @@ def eval_top_k(self, it, gfn_states=None, random_states=None):
# sample random states from uniform actions
if not random_states:
batch = Batch(
env=self.gfn.env, device=self.gfn.device, float_type=self.gfn.float
env=self.gfn.env,
proxy=self.gfn.proxy,
device=self.gfn.device,
float_type=self.gfn.float,
)
self.gfn.random_action_prob = 1.0
print("[eval_top_k] Sampling at random...", end="\r")
@@ -264,16 +270,14 @@ def compute_log_prob_metrics(self, x_tt, metrics=None):
lp_metrics["mean_probs_std"] = probs_std.mean().item()

if "reward_batch" in reqs:
rewards_x_tt = self.gfn.env.reward_batch(x_tt)
rewards_x_tt = self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt))

if "corr_prob_traj_rewards" in metrics:
rewards_x_tt = self.gfn.env.reward_batch(x_tt)
lp_metrics["corr_prob_traj_rewards"] = np.corrcoef(
np.exp(logprobs_x_tt.cpu().numpy()), rewards_x_tt
)[0, 1]

if "var_logrewards_logp" in metrics:
rewards_x_tt = self.gfn.env.reward_batch(x_tt)
lp_metrics["var_logrewards_logp"] = torch.var(
torch.log(
tfloat(
@@ -342,9 +346,11 @@ def compute_density_metrics(self, x_tt, dict_tt, metrics=None):
x_sampled = batch.get_terminating_states()

if "density_true" in dict_tt:
density_true = dict_tt["density_true"]
density_true = torch2np(dict_tt["density_true"])
else:
rewards = self.gfn.env.reward_batch(x_tt)
rewards = torch2np(
self.gfn.proxy.rewards(self.gfn.env.states2proxy(x_tt))
)
z_true = rewards.sum()
density_true = rewards / z_true
with open(self.gfn.buffer.test_pkl, "wb") as f:
@@ -361,9 +367,8 @@
elif self.gfn.continuous and hasattr(self.gfn.env, "fit_kde"):
batch, _ = self.gfn.sample_batch(n_forward=self.config.n, train=False)
assert batch.is_valid()
x_sampled = batch.get_terminating_states()
x_sampled = batch.get_terminating_states(proxy=True)
# TODO make it work with conditional env
x_sampled = torch2np(self.gfn.env.states2proxy(x_sampled))
x_tt = torch2np(self.gfn.env.states2proxy(x_tt))
kde_pred = self.gfn.env.fit_kde(
x_sampled,
@@ -375,8 +380,9 @@
kde_true = dict_tt["kde_true"]
else:
# Sample from reward via rejection sampling
x_from_reward = self.gfn.env.sample_from_reward(n_samples=self.config.n)
x_from_reward = torch2np(self.gfn.env.states2proxy(x_from_reward))
x_from_reward = self.gfn.env.states2proxy(
self.gfn.sample_from_reward(n_samples=self.config.n)
)
# Fit KDE with samples from reward
kde_true = self.gfn.env.fit_kde(
x_from_reward,
@@ -547,15 +553,26 @@ def plot(self, x_sampled, kde_pred, kde_true, plot_kwargs, **kwargs):
fig_kde_pred = fig_kde_true = fig_reward_samples = None

if hasattr(self.gfn.env, "plot_reward_samples") and x_sampled is not None:
(sample_space_batch, rewards_sample_space) = (
self.gfn.get_sample_space_and_reward()
)
fig_reward_samples = self.gfn.env.plot_reward_samples(
x_sampled, **plot_kwargs
x_sampled,
sample_space_batch,
rewards_sample_space,
**plot_kwargs,
)

if hasattr(self.gfn.env, "plot_kde"):
sample_space_batch, _ = self.gfn.get_sample_space_and_reward()
if kde_pred is not None:
fig_kde_pred = self.gfn.env.plot_kde(kde_pred, **plot_kwargs)
fig_kde_pred = self.gfn.env.plot_kde(
sample_space_batch, kde_pred, **plot_kwargs
)
if kde_true is not None:
fig_kde_true = self.gfn.env.plot_kde(kde_true, **plot_kwargs)
fig_kde_true = self.gfn.env.plot_kde(
sample_space_batch, kde_true, **plot_kwargs
)

return {
"True reward and GFlowNet samples": fig_reward_samples,
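Note on the plot() changes above: the environment plotting helpers now receive the representative sample space (and its rewards) computed by the agent instead of deriving them internally. Inferred from the call sites in this diff (an assumption, not the repository's actual code), a compatible environment would expose signatures along these lines:

# Hypothetical environment stubs matching the updated Evaluator.plot() call sites.
class SomeContinuousEnv:
    def plot_reward_samples(self, samples, sample_space_batch, rewards_sample_space, **kwargs):
        """Plot the reward over the sample space together with the sampled states; return a figure."""
        ...

    def plot_kde(self, sample_space_batch, kde, **kwargs):
        """Plot a fitted KDE evaluated over the sample space; return a figure."""
        ...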
33 changes: 31 additions & 2 deletions gflownet/gflownet.py
@@ -169,7 +169,6 @@ def __init__(
**buffer,
env=self.env,
proxy=self.proxy,
make_train_test=not sample_only,
logger=logger,
)
# Train set statistics and reward normalization constant
@@ -1212,6 +1211,36 @@ def train(self):
if self.use_context is False:
self.logger.end()

def get_sample_space_and_reward(self):
"""
Returns samples representative of the env state space with their rewards
Returns
-------
sample_space_batch : tensor
Representative terminating states for the environment
rewards_sample_space : tensor
Rewards associated with the states in sample_space_batch
"""
if not hasattr(self, "sample_space_batch"):
if hasattr(self.env, "get_all_terminating_states"):
self.sample_space_batch = self.env.get_all_terminating_states()
elif hasattr(self.env, "get_grid_terminating_states"):
self.sample_space_batch = self.env.get_grid_terminating_states(
self.evaluator.config.n_grid
)
else:
raise NotImplementedError(
"In order to obtain representative terminating states, the "
"environment must implement either get_all_terminating_states() "
"or get_grid_terminating_states()"
)
self.sample_space_batch = self.env.states2proxy(self.sample_space_batch)
if not hasattr(self, "rewards_sample_space"):
self.rewards_sample_space = self.proxy.rewards(self.sample_space_batch)

return self.sample_space_batch, self.rewards_sample_space

# TODO: implement other proposal distributions
# TODO: rethink whether it is needed to convert to reward
def sample_from_reward(
@@ -1243,7 +1272,7 @@ def sample_from_reward(
format.
"""
samples_final = []
max_reward = self.get_max_reward()
max_reward = self.proxy.get_max_reward()
while len(samples_final) < n_samples:
if proposal_distribution == "uniform":
# TODO: sample only the remaining number of samples
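Note: get_sample_space_and_reward() caches its results on the agent (self.sample_space_batch, self.rewards_sample_space), so the evaluator can call it from several plotting paths without recomputing the grid. A short usage sketch (a ready-made gfn agent is assumed; the names come from this diff):

# Representative terminating states (already converted to proxy format) and their rewards;
# repeated calls reuse the cached attributes.
sample_space_batch, rewards_sample_space = gfn.get_sample_space_and_reward()
assert len(sample_space_batch) == len(rewards_sample_space)  # one reward per state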
1 change: 0 additions & 1 deletion gflownet/utils/buffer.py
@@ -20,7 +20,6 @@ def __init__(
self,
env,
proxy,
make_train_test=False,
replay_capacity=0,
output_csv=None,
data_path=None,
8 changes: 5 additions & 3 deletions gflownet/utils/common.py
@@ -251,12 +251,13 @@ def gflownet_from_config(config):
)

# The proxy is passed to env and used for computing rewards
env = instantiate(
env_maker = instantiate(
config.env,
proxy=proxy,
device=config.device,
float_precision=config.float_precision,
_partial_=True,
)
env = env_maker()

# The evaluator is used to compute metrics and plots
evaluator = instantiate(config.evaluator)
@@ -296,7 +297,8 @@
config.gflownet,
device=config.device,
float_precision=config.float_precision,
env=env,
env_maker=env_maker,
proxy=proxy,
forward_policy=forward_policy,
backward_policy=backward_policy,
state_flow=state_flow,
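Note on the _partial_=True change above: with Hydra, instantiate(..., _partial_=True) returns a functools.partial of the target class rather than an instance, so env_maker() can build environments on demand; passing env_maker (and the proxy) to the GFlowNetAgent presumably lets the agent construct additional environments later, though that rationale is an inference from this diff. A minimal sketch of the pattern (the config node is illustrative):

from hydra.utils import instantiate

# config.env carries _target_ plus constructor kwargs; _partial_=True defers construction.
env_maker = instantiate(config.env, proxy=proxy, _partial_=True)  # a functools.partial
env = env_maker()        # build one environment now
extra_env = env_maker()  # more can be built later from the same maker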
