Merge pull request #253 from alexhernandezgarcia/new_fl_loss
Forward-looking (FL) and Detailed Balance (DB) losses
alexhernandezgarcia authored Dec 26, 2023
2 parents bf65b57 + fe37c34 commit 72b4713
Showing 19 changed files with 704 additions and 67 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -47,6 +47,15 @@ The above command will overwrite the `env` and `proxy` default configuration wit

Hydra configuration is hierarchical. For instance, a handy setting to change while debugging our code is to disable logging to wandb, which you can do by setting `logger.do.online=False`.

## GFlowNet loss functions

Currently, the implementation includes the following GFlowNet losses:

- [Flow-matching (FM)](https://arxiv.org/abs/2106.04399): `gflownet=flowmatch`
- [Trajectory balance (TB)](https://arxiv.org/abs/2201.13259): `gflownet=trajectorybalance`
- [Detailed balance (DB)](https://arxiv.org/abs/2201.13259): `gflownet=detailedbalance`
- [Forward-looking (FL)](https://arxiv.org/abs/2302.01687): `gflownet=forwardlooking`
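
For example, assuming the `python main.py` entry point used in the examples above, training with the forward-looking loss can be launched with `python main.py gflownet=forwardlooking`; other overrides such as `env` or `proxy` can be combined as usual.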

## Logging to wandb

The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`.
47 changes: 47 additions & 0 deletions config/experiments/simple_tetris.yaml
@@ -0,0 +1,47 @@
# @package _global_

defaults:
  - override /env: tetris
  - override /gflownet: trajectorybalance
  - override /policy: mlp
  - override /proxy: tetris
  - override /logger: wandb

env:
  reward_func: boltzmann
  reward_beta: 10.0
  width: 4
  height: 4
  pieces: ["I", "O", "J", "L", "T"]
  rotations: [0, 90, 180, 270]
  buffer:
    # replay_capacity: 0
    test:
      type: random
      output_csv: simple_tetris_val.csv
      output_pkl: simple_tetris_val.pkl
      n: 100

gflownet:
  random_action_prob: 0.3
  optimizer:
    n_train_steps: 10000
    lr_z_mult: 100
    lr: 0.0001

policy:
  forward:
    type: mlp
    n_hid: 128
    n_layers: 5

  backward:
    shared_weights: True
    checkpoint: null
    reload_ckpt: False

device: cpu
logger:
  do:
    online: True
  project_name: simple_tetris
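
This experiment file would typically be launched with Hydra's experiment syntax, e.g. `python main.py +experiments=simple_tetris`; the `+experiments=` group name and the `main.py` entry point are assumptions based on the file's location under `config/experiments/`.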
9 changes: 9 additions & 0 deletions config/gflownet/detailedbalance.yaml
@@ -0,0 +1,9 @@
defaults:
  - gflownet
  - state_flow: mlp

optimizer:
  loss: detailedbalance
  lr: 0.0001
  lr_decay_period: 1000000
  lr_decay_gamma: 0.5
9 changes: 9 additions & 0 deletions config/gflownet/forwardlooking.yaml
@@ -0,0 +1,9 @@
defaults:
  - gflownet
  - state_flow: mlp

optimizer:
  loss: forwardlooking
  lr: 0.0001
  lr_decay_period: 1000000
  lr_decay_gamma: 0.5
2 changes: 2 additions & 0 deletions config/gflownet/gflownet.yaml
@@ -34,6 +34,8 @@ optimizer:
  # From original implementation
  bootstrap_tau: 0.0
  clip_grad_norm: 0.0
# State flow modelling
state_flow: null
# If True, compute rewards in batches
batch_reward: True
# Force zero probability of sampling invalid actions
9 changes: 9 additions & 0 deletions config/gflownet/state_flow/mlp.yaml
@@ -0,0 +1,9 @@
_target_: gflownet.policy.state_flow.StateFlow

config:
  type: mlp
  n_hid: 128
  n_layers: 2
  checkpoint: null
  reload_ckpt: False
  shared_weights: False
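
As a point of reference, below is a minimal sketch of what a state-flow module consistent with this config could look like: an MLP mapping a batch of policy-format states to one scalar log-flow per state. This is an illustration only, not the repository's `StateFlow` class; the class name, constructor signature and dimensions are assumptions.

```python
import torch
from torch import nn


class StateFlowSketch(nn.Module):
    """Toy stand-in for a state log-flow estimator: maps each state to one scalar."""

    def __init__(self, state_dim: int, n_hid: int = 128, n_layers: int = 2):
        super().__init__()
        layers, in_dim = [], state_dim
        for _ in range(n_layers):
            layers += [nn.Linear(in_dim, n_hid), nn.LeakyReLU()]
            in_dim = n_hid
        layers.append(nn.Linear(in_dim, 1))  # one scalar (log-flow) per state
        self.model = nn.Sequential(*layers)

    def forward(self, states: torch.Tensor) -> torch.Tensor:
        # states: [batch, state_dim] -> log-flows: [batch]
        return self.model(states).squeeze(-1)


# Usage sketch: log-flows for 8 random states of dimension 16
logflows = StateFlowSketch(state_dim=16)(torch.randn(8, 16))  # shape [8]
```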
7 changes: 7 additions & 0 deletions config/policy/mlp_detailedbalance.yaml
@@ -0,0 +1,7 @@
defaults:
  - mlp

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
7 changes: 7 additions & 0 deletions config/policy/mlp_forwardlooking.yaml
@@ -0,0 +1,7 @@
defaults:
  - mlp

backward:
  shared_weights: True
  checkpoint: null
  reload_ckpt: False
2 changes: 1 addition & 1 deletion config/user/alex.yaml
@@ -1,5 +1,5 @@
logdir:
  root: /network/scratch/h/hernanga/logs/gflownet
  root: /home/alex/logs/gflownet
data:
  root: /home/mila/h/hernanga/gflownet/data
  alanine_dipeptide: /home/mila/h/hernanga/gflownet/data/alanine_dipeptide_conformers_1.npy
4 changes: 2 additions & 2 deletions gflownet/envs/base.py
@@ -762,13 +762,13 @@ def traj2readable(self, traj=None):
"""
return str(traj).replace("(", "[").replace(")", "]").replace(",", "")

def reward(self, state=None, done=None):
def reward(self, state=None, done=None, do_non_terminating=False):
"""
Computes the reward of a state
"""
state = self._get_state(state)
done = self._get_done(done)
if done is False:
if not done and not do_non_terminating:
return tfloat(0.0, float_type=self.float, device=self.device)
return self.proxy2reward(
self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0]
180 changes: 169 additions & 11 deletions gflownet/gflownet.py
@@ -50,6 +50,7 @@ def __init__(
        logger,
        num_empirical_loss,
        oracle,
        state_flow=None,
        active_learning=False,
        sample_only=False,
        replay_sampling="permutation",
@@ -79,6 +80,12 @@ def __init__(
        elif optimizer.loss in ["trajectorybalance", "tb"]:
            self.loss = "trajectorybalance"
            self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64)
        elif optimizer.loss in ["detailedbalance", "db"]:
            self.loss = "detailedbalance"
            self.logZ = None
        elif optimizer.loss in ["forwardlooking", "fl"]:
            self.loss = "forwardlooking"
            self.logZ = None
        else:
            print("Unknown loss. Using flowmatch as default")
            self.loss = "flowmatch"
@@ -121,7 +128,8 @@ def __init__(
print(f"\tStd score: {self.buffer.test['energies'].std()}")
print(f"\tMin score: {self.buffer.test['energies'].min()}")
print(f"\tMax score: {self.buffer.test['energies'].max()}")
# Policy models

# Models
self.forward_policy = forward_policy
if self.forward_policy.checkpoint is not None:
self.logger.set_forward_policy_ckpt_path(self.forward_policy.checkpoint)
@@ -133,6 +141,7 @@ def __init__(
print("Reloaded GFN forward policy model Checkpoint")
else:
self.logger.set_forward_policy_ckpt_path(None)

self.backward_policy = backward_policy
self.logger.set_backward_policy_ckpt_path(None)
if self.backward_policy.checkpoint is not None:
@@ -145,6 +154,14 @@ def __init__(
print("Reloaded GFN backward policy model Checkpoint")
else:
self.logger.set_backward_policy_ckpt_path(None)

self.state_flow = state_flow
if self.state_flow is not None and self.state_flow.checkpoint is not None:
self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint)
# TODO: add the logic and conditions to reload a model
else:
self.logger.set_state_flow_ckpt_path(None)

        # Optimizer
        if self.forward_policy.is_model:
            self.target = copy.deepcopy(self.forward_policy.model)
@@ -178,14 +195,16 @@ def __init__(
        self.nll_tt = 0.0

    def parameters(self):
        if self.backward_policy.is_model is False:
            return list(self.forward_policy.model.parameters())
        elif self.loss == "trajectorybalance":
            return list(self.forward_policy.model.parameters()) + list(
                self.backward_policy.model.parameters()
            )
        else:
            raise ValueError("Backward Policy cannot be a nn in flowmatch.")
        parameters = list(self.forward_policy.model.parameters())
        if self.backward_policy.is_model:
            if self.loss == "flowmatch":
                raise ValueError("Backward Policy cannot be a model in flowmatch.")
            parameters += list(self.backward_policy.model.parameters())
        if self.state_flow is not None:
            if self.loss not in ["detailedbalance", "forwardlooking"]:
                raise ValueError(f"State flow cannot be trained with {self.loss} loss.")
            parameters += list(self.state_flow.model.parameters())
        return parameters

    def sample_actions(
        self,
@@ -662,6 +681,137 @@ def trajectorybalance_loss(self, it, batch):
        )
        return loss, loss, loss

    def detailedbalance_loss(self, it, batch):
        """
        Computes the Detailed Balance GFlowNet loss of a batch
        Reference : https://arxiv.org/pdf/2201.13259.pdf (eq 11)
        Args
        ----
        it : int
            Iteration
        batch : Batch
            A batch of data, containing all the states in the trajectories.
        Returns
        -------
        loss : float
        term_loss : float
            Loss of the terminal nodes only
        nonterm_loss : float
            Loss of the intermediate nodes only
        """

        assert batch.is_valid()
        # Get necessary tensors from batch
        states = batch.get_states(policy=False)
        states_policy = batch.get_states(policy=True)
        actions = batch.get_actions()
        parents = batch.get_parents(policy=False)
        parents_policy = batch.get_parents(policy=True)
        done = batch.get_done()
        rewards = batch.get_terminating_rewards(sort_by="insertion")

        # Get logprobs
        masks_f = batch.get_masks_forward(of_parents=True)
        policy_output_f = self.forward_policy(parents_policy)
        logprobs_f = self.env.get_logprobs(
            policy_output_f, actions, masks_f, parents, is_backward=False
        )
        masks_b = batch.get_masks_backward()
        policy_output_b = self.backward_policy(states_policy)
        logprobs_b = self.env.get_logprobs(
            policy_output_b, actions, masks_b, states, is_backward=True
        )

        # Get logflows
        logflows_states = self.state_flow(states_policy)
        logflows_states[done.eq(1)] = torch.log(rewards)
        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
        logflows_parents = self.state_flow(parents_policy)

        # Detailed balance loss
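        # Per-transition objective (eq. 11 of the reference above):
        # (log F(s_t) + log PF(s_{t+1}|s_t) - log F(s_{t+1}) - log PB(s_t|s_{t+1}))^2,
        # with log F(s_{t+1}) replaced by log R(s_{t+1}) at terminal transitions (set above)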
        loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)
        loss = loss_all.mean()
        loss_terminating = loss_all[done].mean()
        loss_intermediate = loss_all[~done].mean()
        return loss, loss_terminating, loss_intermediate

    def forwardlooking_loss(self, it, batch):
        """
        Computes the Forward-Looking GFlowNet loss of a batch
        Reference : https://arxiv.org/pdf/2302.01687.pdf
        Args
        ----
        it : int
            Iteration
        batch : Batch
            A batch of data, containing all the states in the trajectories.
        Returns
        -------
        loss : float
        term_loss : float
            Loss of the terminal nodes only
        nonterm_loss : float
            Loss of the intermediate nodes only
        """

        assert batch.is_valid()
        # Get necessary tensors from batch
        states = batch.get_states(policy=False)
        states_policy = batch.get_states(policy=True)
        actions = batch.get_actions()
        parents = batch.get_parents(policy=False)
        parents_policy = batch.get_parents(policy=True)
        rewards_states = batch.get_rewards(do_non_terminating=True)
        rewards_parents = batch.get_rewards_parents()
        done = batch.get_done()

        # Get logprobs
        masks_f = batch.get_masks_forward(of_parents=True)
        policy_output_f = self.forward_policy(parents_policy)
        logprobs_f = self.env.get_logprobs(
            policy_output_f, actions, masks_f, parents, is_backward=False
        )
        masks_b = batch.get_masks_backward()
        policy_output_b = self.backward_policy(states_policy)
        logprobs_b = self.env.get_logprobs(
            policy_output_b, actions, masks_b, states, is_backward=True
        )

        # Get FL logflows
        logflflows_states = self.state_flow(states_policy)
        # Log FL flow of terminal states is 0 (eq. 9 of paper)
        logflflows_states[done.eq(1)] = 0.0
        # TODO: Optimise by reusing logflows_states and batch.get_parent_indices
        logflflows_parents = self.state_flow(parents_policy)

        # Get energies transitions
        energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states)

        # Forward-looking loss
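        # Per-transition objective, with energy E(s) = -log R(s) so that
        # energies_transitions = E(s_{t+1}) - E(s_t):
        # (log Ftilde(s_t) - log Ftilde(s_{t+1}) + log PF(s_{t+1}|s_t)
        #  - log PB(s_t|s_{t+1}) + E(s_{t+1}) - E(s_t))^2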
        loss_all = (
            logflflows_parents
            - logflflows_states
            + logprobs_f
            - logprobs_b
            + energies_transitions
        ).pow(2)
        loss = loss_all.mean()
        loss_terminating = loss_all[done].mean()
        loss_intermediate = loss_all[~done].mean()
        return loss, loss_terminating, loss_intermediate

    @torch.no_grad()
    def estimate_logprobs_data(
        self,
@@ -874,6 +1024,10 @@ def train(self):
                    losses = self.trajectorybalance_loss(
                        it * self.ttsr + j, batch
                    )  # returns (opt loss, *metrics)
                elif self.loss == "detailedbalance":
                    losses = self.detailedbalance_loss(it * self.ttsr + j, batch)
                elif self.loss == "forwardlooking":
                    losses = self.forwardlooking_loss(it * self.ttsr + j, batch)
                else:
                    print("Unknown loss!")
                    # TODO: deal with this in a better way
@@ -937,7 +1091,9 @@ def train(self):
times.update({"log": t1_log - t0_log})
# Save intermediate models
t0_model = time.time()
self.logger.save_models(self.forward_policy, self.backward_policy, step=it)
self.logger.save_models(
self.forward_policy, self.backward_policy, self.state_flow, step=it
)
t1_model = time.time()
times.update({"save_interim_model": t1_model - t0_model})

@@ -966,7 +1122,9 @@ def train(self):
        self.logger.log_time(times, use_context=self.use_context)

        # Save final model
        self.logger.save_models(self.forward_policy, self.backward_policy, final=True)
        self.logger.save_models(
            self.forward_policy, self.backward_policy, self.state_flow, final=True
        )
        # Close logger
        if self.use_context is False:
            self.logger.end()