From c2559f76cff1dcbc4576671938cf0167d9f47ef4 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 14 Nov 2023 13:37:26 -0500 Subject: [PATCH 01/34] simple tetris config --- config/experiments/simple_tetris.yaml | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 config/experiments/simple_tetris.yaml diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml new file mode 100644 index 000000000..82673b3fc --- /dev/null +++ b/config/experiments/simple_tetris.yaml @@ -0,0 +1,17 @@ +defaults: + - override /env: tetris + - override /gflownet: trajectorybalance + - override /policy: mlp + - override /proxy: tetris + - override /logger: wandb + +env: + width: 4 + height: 4 + pieces: ["I", "L", "O"] + +device: cpu +logger: + do: + online: True + project_name: simple_tetris \ No newline at end of file From 21cffef9f58de7c093658d0b9fb3b7a0ee1c1d95 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Tue, 14 Nov 2023 14:17:09 -0500 Subject: [PATCH 02/34] config fixes --- config/experiments/simple_tetris.yaml | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml index 82673b3fc..41b5c247f 100644 --- a/config/experiments/simple_tetris.yaml +++ b/config/experiments/simple_tetris.yaml @@ -1,3 +1,5 @@ +# @package _global_ + defaults: - override /env: tetris - override /gflownet: trajectorybalance @@ -9,6 +11,19 @@ env: width: 4 height: 4 pieces: ["I", "L", "O"] + rotations: [0, 90] + buffer: + # replay_capacity: 0 + test: + type: uniform + output_csv: simple_tetris_val.csv + output_pkl: simple_tetris_val.pkl + +gflownet: + optimizer: + n_train_steps: 10000 + lr_z_mult: 100 + lr: 0.0001 device: cpu logger: From 8829988bc63753528c428253563485feacff664f Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 16 Nov 2023 16:42:51 -0500 Subject: [PATCH 03/34] final config for simple tetris --- config/experiments/simple_tetris.yaml | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml index 41b5c247f..716169dae 100644 --- a/config/experiments/simple_tetris.yaml +++ b/config/experiments/simple_tetris.yaml @@ -8,23 +8,38 @@ defaults: - override /logger: wandb env: + reward_func: boltzmann + reward_beta: 10.0 width: 4 height: 4 - pieces: ["I", "L", "O"] - rotations: [0, 90] + pieces: ["I", "O", "J", "L", "T"] + rotations: [0, 90, 180, 270] buffer: # replay_capacity: 0 test: - type: uniform + type: random output_csv: simple_tetris_val.csv output_pkl: simple_tetris_val.pkl + n: 100 gflownet: + random_action_prob: 0.3 optimizer: n_train_steps: 10000 lr_z_mult: 100 lr: 0.0001 +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 5 + + backward: + shared_weights: True + checkpoint: null + reload_ckpt: False + device: cpu logger: do: From 97430aa5941003b6fab7d618111037f33da22f7b Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 23 Nov 2023 15:37:26 -0500 Subject: [PATCH 04/34] state flow class --- config/main.yaml | 1 + config/state_flow/mlp.yaml | 9 ++++ gflownet/gflownet.py | 57 ++++++++++++++------- gflownet/policy/base.py | 93 ++++++++++++++++++----------------- gflownet/policy/state_flow.py | 25 ++++++++++ gflownet/utils/batch.py | 12 ++++- gflownet/utils/logger.py | 16 +++++- gflownet/utils/policy.py | 2 +- main.py | 12 +++++ 9 files changed, 159 insertions(+), 68 deletions(-) create mode 100644 config/state_flow/mlp.yaml create mode 100644 gflownet/policy/state_flow.py diff --git a/config/main.yaml b/config/main.yaml index 7ea98e735..c550bfaf6 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -3,6 +3,7 @@ defaults: - env: grid - gflownet: flowmatch - policy: mlp_${gflownet} + - state_flow: null - proxy: corners - logger: wandb - user: alex diff --git a/config/state_flow/mlp.yaml b/config/state_flow/mlp.yaml new file mode 100644 index 000000000..6ccf772ee --- /dev/null +++ b/config/state_flow/mlp.yaml @@ -0,0 +1,9 @@ +_target_: gflownet.policy.state_flow.StateFlow + +config: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: null + reload_ckpt: False + shared_weights: False \ No newline at end of file diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index c9cac3146..e46640173 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -50,6 +50,7 @@ def __init__( logger, num_empirical_loss, oracle, + state_flow=None, active_learning=False, sample_only=False, replay_sampling="permutation", @@ -76,9 +77,15 @@ def __init__( if optimizer.loss in ["flowmatch", "flowmatching"]: self.loss = "flowmatch" self.logZ = None + self.non_terminal_rewards = False elif optimizer.loss in ["trajectorybalance", "tb"]: self.loss = "trajectorybalance" self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64) + self.non_terminal_rewards = False + elif optimizer.loss in ["forwardlooking", "fl"]: + self.loss = "forwardlooking" + self.logZ = None + self.non_terminal_rewards = True else: print("Unkown loss. Using flowmatch as default") self.loss = "flowmatch" @@ -121,7 +128,8 @@ def __init__( print(f"\tStd score: {self.buffer.test['energies'].std()}") print(f"\tMin score: {self.buffer.test['energies'].min()}") print(f"\tMax score: {self.buffer.test['energies'].max()}") - # Policy models + + # Models self.forward_policy = forward_policy if self.forward_policy.checkpoint is not None: self.logger.set_forward_policy_ckpt_path(self.forward_policy.checkpoint) @@ -133,6 +141,7 @@ def __init__( print("Reloaded GFN forward policy model Checkpoint") else: self.logger.set_forward_policy_ckpt_path(None) + self.backward_policy = backward_policy self.logger.set_backward_policy_ckpt_path(None) if self.backward_policy.checkpoint is not None: @@ -145,6 +154,14 @@ def __init__( print("Reloaded GFN backward policy model Checkpoint") else: self.logger.set_backward_policy_ckpt_path(None) + + self.state_flow = state_flow + if self.state_flow is not None and self.state_flow.checkpoint is not None: + self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint) + # TODO: add the logic and conditions to reload a model + else: + self.logger.set_state_flow_ckpt_path(None) + # Optimizer if self.forward_policy.is_model: self.target = copy.deepcopy(self.forward_policy.model) @@ -178,14 +195,16 @@ def __init__( self.nll_tt = 0.0 def parameters(self): - if self.backward_policy.is_model is False: - return list(self.forward_policy.model.parameters()) - elif self.loss == "trajectorybalance": - return list(self.forward_policy.model.parameters()) + list( - self.backward_policy.model.parameters() - ) - else: - raise ValueError("Backward Policy cannot be a nn in flowmatch.") + parameters = list(self.forward_policy.model.parameters()) + if self.backward_policy.is_model: + if self.loss == "flowmatch": + raise ValueError("Backward Policy cannot be a nn in flowmatch.") + parameters += list(self.backward_policy.model.parameters()) + if self.state_flow is not None: + if self.loss != "forwardlooking": + raise ValueError(f"State flow cannot be trained in {self.loss} loss.") + parameters += list(self.state_flow.model.parameters()) + return parameters def sample_actions( self, @@ -405,12 +424,12 @@ def sample_batch( "actions_envs": 0.0, } t0_all = time.time() - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) # ON-POLICY FORWARD trajectories t0_forward = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_forward)] - batch_forward = Batch(env=self.env, device=self.device, float_type=self.float) + batch_forward = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) while envs: # Sample actions t0_a_envs = time.time() @@ -432,7 +451,7 @@ def sample_batch( # TRAIN BACKWARD trajectories t0_train = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_train)] - batch_train = Batch(env=self.env, device=self.device, float_type=self.float) + batch_train = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: dict_tr = pickle.load(f) @@ -463,7 +482,7 @@ def sample_batch( # REPLAY BACKWARD trajectories t0_replay = time.time() - batch_replay = Batch(env=self.env, device=self.device, float_type=self.float) + batch_replay = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) if n_replay > 0 and self.buffer.replay_pkl is not None: with open(self.buffer.replay_pkl, "rb") as f: dict_replay = pickle.load(f) @@ -753,7 +772,7 @@ def estimate_logprobs_data( end_batch = min(batch_size, n_states) pbar = tqdm(total=n_states) while init_batch < n_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) # Create an environment for each data point and trajectory and set the state envs = [] for state_idx in range(init_batch, end_batch): @@ -852,7 +871,7 @@ def train(self): self.logger.log_metrics(metrics, use_context=self.use_context, step=it) self.logger.log_summary(summary) t0_iter = time.time() - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) for j in range(self.sttr): sub_batch, times = self.sample_batch( n_forward=self.batch_size.forward, @@ -932,7 +951,7 @@ def train(self): times.update({"log": t1_log - t0_log}) # Save intermediate models t0_model = time.time() - self.logger.save_models(self.forward_policy, self.backward_policy, step=it) + self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, step=it) t1_model = time.time() times.update({"save_interim_model": t1_model - t0_model}) @@ -961,7 +980,7 @@ def train(self): self.logger.log_time(times, use_context=self.use_context) # Save final model - self.logger.save_models(self.forward_policy, self.backward_policy, final=True) + self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, final=True) # Close logger if self.use_context is False: self.logger.end() @@ -1136,7 +1155,7 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): print() if not gfn_states: # sample states from the current gfn - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) self.random_action_prob = 0 t = time.time() print("Sampling from GFN...", end="\r") @@ -1159,7 +1178,7 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): if do_random: # sample random states from uniform actions if not random_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float) + batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) self.random_action_prob = 1.0 print("[test_top_k] Sampling at random...", end="\r") for b in batch_with_rest( diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 766231481..5122891a4 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,55 +1,35 @@ import torch from omegaconf import OmegaConf from torch import nn - +from abc import ABC, abstractmethod from gflownet.utils.common import set_device, set_float_precision -class Policy: - def __init__(self, config, env, device, float_precision, base=None): +class ModelBase(ABC): + def __init__(self, config, input_dim, device, float_precision, base=None): # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) - # Input and output dimensions - self.state_dim = env.policy_input_dim - self.fixed_output = torch.tensor(env.fixed_policy_output).to( - dtype=self.float, device=self.device - ) - self.random_output = torch.tensor(env.random_policy_output).to( - dtype=self.float, device=self.device - ) - self.output_dim = len(self.fixed_output) + # Input dimension + self.input_dim = input_dim + # Must be redefined in the children classes + self.output_dim = None + # Optional base model self.base = base self.parse_config(config) - self.instantiate() def parse_config(self, config): # If config is null, default to uniform if config is None: config = OmegaConf.create() config.type = "uniform" - if "checkpoint" in config: - self.checkpoint = config.checkpoint - else: - self.checkpoint = None - if "shared_weights" in config: - self.shared_weights = config.shared_weights - else: - self.shared_weights = False - if "n_hid" in config: - self.n_hid = config.n_hid - else: - self.n_hid = None - if "n_layers" in config: - self.n_layers = config.n_layers - else: - self.n_layers = None - if "tail" in config: - self.tail = config.tail - else: - self.tail = [] + self.checkpoint = config.get("checkpoint", None) + self.shared_weights = config.get("shared_weights", False) + self.n_hid = config.get("n_hid", None) + self.n_layers = config.get("n_layers", None) + self.tail = config.get("tail", []) if "type" in config: self.type = config.type elif self.shared_weights: @@ -57,18 +37,9 @@ def parse_config(self, config): else: raise "Policy type must be defined if shared_weights is False" + @abstractmethod def instantiate(self): - if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False - elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False - elif self.type == "mlp": - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True - else: - raise "Policy model type not defined" + pass def __call__(self, states): return self.model(states) @@ -95,7 +66,7 @@ def make_mlp(self, activation): return mlp elif self.shared_weights == False: layers_dim = ( - [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] + [self.input_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] ) mlp = nn.Sequential( *( @@ -118,6 +89,38 @@ def make_mlp(self, activation): "Base Model must be provided when shared_weights is set to True" ) + + +class Policy(ModelBase): + def __init__(self, config, env, device, float_precision, base=None): + super().__init__(config, env.policy_input_dim, device, float_precision, base) + + # Outputs + + self.fixed_output = torch.tensor(env.fixed_policy_output).to( + dtype=self.float, device=self.device + ) + self.random_output = torch.tensor(env.random_policy_output).to( + dtype=self.float, device=self.device + ) + self.output_dim = len(self.fixed_output) + + self.instantiate() + + + def instantiate(self): + if self.type == "fixed": + self.model = self.fixed_distribution + self.is_model = False + elif self.type == "uniform": + self.model = self.uniform_distribution + self.is_model = False + elif self.type == "mlp": + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True + else: + raise "Policy model type not defined" + def fixed_distribution(self, states): """ Returns the fixed distribution specified by the environment. diff --git a/gflownet/policy/state_flow.py b/gflownet/policy/state_flow.py new file mode 100644 index 000000000..6417dfac2 --- /dev/null +++ b/gflownet/policy/state_flow.py @@ -0,0 +1,25 @@ +import torch +from torch import nn + +from gflownet.utils.common import set_device, set_float_precision +from gflownet.policy.base import ModelBase + +class StateFlow(ModelBase): + """ + Takes state in the policy format and predicts its flow (a scalar) + """ + def __init__(self, config, env, device, float_precision, base=None): + super().__init__(config, env.policy_input_dim, device, float_precision, base) + + # output dim + self.output_dim = 1 + + # Instantiate neural network + self.instantiate() + + def instantiate(self): + if self.type == "mlp": + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True + else: + raise "StateFlow model type not defined" \ No newline at end of file diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index a35f01ddf..6ca38ff6d 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -39,6 +39,7 @@ def __init__( env: Optional[GFlowNetEnv] = None, device: Union[str, torch.device] = "cpu", float_type: Union[int, torch.dtype] = 32, + non_terminal_rewards: bool = False ): """ env : GFlowNetEnv @@ -56,6 +57,8 @@ def __init__( self.device = set_device(device) # Float precision self.float = set_float_precision(float_type) + # Whether rewards should be computed for non-terminal states + self.non_terminal_rewards = non_terminal_rewards # Generic environment, properties and dictionary of state and forward mask of # source (as tensor) if env is not None: @@ -843,10 +846,15 @@ def _compute_rewards(self): rewards: torch.tensor Tensor of rewards. """ - states_proxy_done = self.get_terminating_states(proxy=True) + self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) + states_proxy_done = self.get_terminating_states(proxy=True) done = self.get_done() - if len(done) > 0: + if self.non_terminal_rewards: + self.rewards = self.env.proxy2reward( + self.env.proxy(self.states2proxy()) + ) + elif len(done) > 0: self.rewards[done] = self.env.proxy2reward( self.env.proxy(states_proxy_done) ) diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index e9556f9c5..0012317e8 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -144,6 +144,12 @@ def set_backward_policy_ckpt_path(self, ckpt_id: str = None): else: self.pb_ckpt_path = self.ckpts_dir / f"{ckpt_id}_" + def set_state_flow_ckpt_path(self, ckpt_id: str = None): + if ckpt_id is None: + self.sf_ckpt_path = None + else: + self.sf_ckpt_path = self.ckpts_dir / f"{ckpt_id}_" + def progressbar_update( self, pbar, losses, rewards, jsd, step, use_context=True, n_mean=100 ): @@ -356,7 +362,7 @@ def log_test_metrics( ) def save_models( - self, forward_policy, backward_policy, step: int = 1e9, final=False + self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False ): if self.do_checkpoints(step) or final: if final: @@ -376,6 +382,14 @@ def save_models( path = self.pb_ckpt_path.parent / stem torch.save(backward_policy.model.state_dict(), path) + if state_flow is not None and self.sf_ckpt_path is not None: + stem = self.sf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" + path = self.sf_ckpt_path.parent / stem + torch.save(state_flow.model.state_dict(), path) + + + + def log_time(self, times: dict, use_context: bool): if self.do.times: times = {"time_{}".format(k): v for k, v in times.items()} diff --git a/gflownet/utils/policy.py b/gflownet/utils/policy.py index 973e16c84..971e8a2a9 100644 --- a/gflownet/utils/policy.py +++ b/gflownet/utils/policy.py @@ -30,4 +30,4 @@ def parse_policy_config(config: DictConfig, kind: str) -> Optional[DictConfig]: del policy_config.backward del policy_config.shared - return policy_config + return policy_config \ No newline at end of file diff --git a/main.py b/main.py index 127e59001..79ae422f8 100644 --- a/main.py +++ b/main.py @@ -61,6 +61,17 @@ def main(config): base=forward_policy, ) + if config.gflownet.optimizer.loss in ["forwardlooking", "fl"]: + state_flow = hydra.utils.instantiate( + config.state_flow, + env=env, + device=config.device, + float_precision=config.float_precision, + base=forward_policy, + ) + else: + state_flow = None + gflownet = hydra.utils.instantiate( config.gflownet, device=config.device, @@ -68,6 +79,7 @@ def main(config): env=env, forward_policy=forward_policy, backward_policy=backward_policy, + state_flow=state_flow, buffer=config.env.buffer, logger=logger, ) From ebb2737268fc256fda81e953e6f62b8adf46f01f Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Thu, 23 Nov 2023 20:41:21 -0500 Subject: [PATCH 05/34] fl loss functionin gfn + updates in batch class --- gflownet/gflownet.py | 70 ++++++++++++++++++++++++++++++++++++++ gflownet/utils/batch.py | 75 ++++++++++++++++++++++++++++++++++++----- 2 files changed, 136 insertions(+), 9 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index e46640173..44c0e030b 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -682,6 +682,74 @@ def trajectorybalance_loss(self, it, batch): ) return loss, loss, loss + def forwardlooking_loss(self, it, batch): + """ + Computes the Forward-Looking GFlowNet loss of a batch + Reference : https://arxiv.org/pdf/2302.01687.pdf + + Args + ---- + it : int + Iteration + + batch : Batch + A batch of data, containing all the states in the trajectories. + + + Returns + ------- + loss : float + + term_loss : float + Loss of the terminal nodes only + + nonterm_loss : float + Loss of the intermediate nodes only + """ + + assert batch.is_valid() + # Get necessary tensors from batch + states_policy = batch.get_states(policy=True) + states = batch.get_states(policy=False) + actions = batch.get_actions() + parents_policy = batch.get_parents(policy=True) + parents = batch.get_parents(policy=False) + traj_indices = batch.get_trajectory_indices(consecutive=True) + done = batch.get_done() + + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_bkw = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) + masks_f = batch.get_masks_forward(of_parents=True) + policy_output_f = self.forward_policy(parents_policy) + logprobs_fwd = self.env.get_logprobs( + policy_output_f, actions, masks_f, parents, is_backward=False + ) + + + states_log_flflow = self.state_flow(states_policy) + # forward-looking flow is 1 in the terminal states + states_log_flflow[done.eq(1)] = 0. + # Can be optimised by reusing states_log_flflow and batch.get_parent_indices + parents_log_flflow = self.state_flow(parents_policy) + + assert batch.non_terminal_rewards + rewards_states = batch.get_rewards() + rewards_parents = batch.get_rewards_parents() + energies_states = -torch.log(rewards_states) + energies_parents = -torch.log(rewards_parents) + + per_node_loss = (parents_log_flflow - states_log_flflow + logprobs_fwd - logprobs_bkw + + energies_states - energies_parents).pow(2) + + term_loss = per_node_loss[done].mean() + nonterm_loss = per_node_loss[~done].mean() + loss = per_node_loss.mean() + + return loss, term_loss, nonterm_loss + @torch.no_grad() def estimate_logprobs_data( self, @@ -888,6 +956,8 @@ def train(self): losses = self.trajectorybalance_loss( it * self.ttsr + j, batch ) # returns (opt loss, *metrics) + elif self.loss == "forwardlooking": + losses = self.forwardlooking_loss(it * self.ttsr + j, batch) else: print("Unknown loss!") # TODO: deal with this in a better way diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 6ca38ff6d..c1d72a703 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -91,12 +91,14 @@ def __init__( self.states_policy = None self.parents_policy = None # Flags for available items - self.parents_available = True + self.parents_available = False self.parents_policy_available = False self.parents_all_available = False self.masks_forward_available = False self.masks_backward_available = False self.rewards_available = False + self.rewards_parents_available = False + self.rewards_source_available = False def __len__(self): return self.size @@ -524,6 +526,16 @@ def get_parents( else: return self.parents + def get_parents_indices(self): + if self.parents_available is False: + self._compute_parents() + return self.parents_indices + + def get_parent_is_source(self): + if self.parents_available is False: + self._compute_parents() + return self.parents_indices == -1 + def _compute_parents(self): """ Obtains the parent (single parent for each state) of all states in the batch. @@ -539,18 +551,24 @@ def _compute_parents(self): self.parents_available is set to True. """ self.parents = [] + self.parents_indices = [] indices = [] # Iterate over the trajectories to obtain the parents from the states for traj_idx, batch_indices in self.trajectories.items(): # parent is source self.parents.append(self.envs[traj_idx].source) + # there's no source state in the batch + self.parents_indices.append(-1) # parent is not source # TODO: check if tensor and sort without iter self.parents.extend([self.states[idx] for idx in batch_indices[:-1]]) + self.parents_indices.extend([idx for idx in batch_indices[:-1]]) indices.extend(batch_indices) # Sort parents list in the same order as states # TODO: check if tensor and sort without iter self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))] + self.parents_indices = tlong([self.parents_indices[indices.index(idx)] for idx in range(len(self))], + device=self.device) self.parents_available = True # TODO: consider converting directly from self.parents @@ -835,31 +853,70 @@ def get_rewards( if self.rewards_available is False or force_recompute is True: self._compute_rewards() return self.rewards - + def _compute_rewards(self): """ Computes rewards for all self.states by first converting the states into proxy - format. - - Returns - ------- - rewards: torch.tensor - Tensor of rewards. + format. The result is stored in self.rewards as a torch.tensor """ self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) - states_proxy_done = self.get_terminating_states(proxy=True) done = self.get_done() if self.non_terminal_rewards: self.rewards = self.env.proxy2reward( self.env.proxy(self.states2proxy()) ) elif len(done) > 0: + states_proxy_done = self.get_terminating_states(proxy=True) self.rewards[done] = self.env.proxy2reward( self.env.proxy(states_proxy_done) ) self.rewards_available = True + def get_rewards_parents(self) -> TensorType["n_states"]: + """ + Returns the rewards of all parents in the batch + """ + if not self.rewards_parents_available: + self._compute_rewards_parents() + return self.rewards_parents + + def _compute_rewards_parents(self): + """ + Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). + Stores the result in self.rewards_parents + """ + state_rewards = self.get_rewards() + self.rewards_parents = torch.zeros_like(state_rewards) + parent_is_source = self.get_parent_is_source() + parent_indices = self.get_parents_indices() + self.rewards_parents[~parent_is_source] = self.rewards[parent_indices[~parent_is_source]] + rewards_source = self.get_rewards_source() + self.rewards_parents[parent_is_source] = rewards_source[parent_is_source] + self.rewards_parents_available = True + + def get_rewards_source(self) -> TensorType["n_states"]: + """ + Returns rewards of the corresponding source states for each state in the batch. + """ + if not self.rewards_source_available: + self._compute_rewards_source() + return self.rewards_source + + def _compute_rewards_source(self): + """ + Computes a tensor of length len(self.states) with rewards of the corresponding source states. + Stores the result in self.rewards_source + """ + # This will not work if source is randomised + if not self.conditional: + source_proxy = self.env.state2proxy(self.env.source) + reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) + self.rewards_source = reward_source.expand(len(self)) + else: + raise NotImplementedError + self.rewards_source_available = True + def get_terminating_states( self, sort_by: str = "insertion", From 64f7c3ba2217e62593e3f78dedd0c2268fb16e46 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Fri, 24 Nov 2023 14:35:19 -0500 Subject: [PATCH 06/34] black, isort --- gflownet/gflownet.py | 95 +++++++++++++++++++++++-------- gflownet/policy/base.py | 12 ++-- gflownet/policy/state_flow.py | 10 ++-- gflownet/utils/batch.py | 24 ++++---- gflownet/utils/logger.py | 3 - gflownet/utils/policy.py | 2 +- main.py | 2 +- mila/launch.py | 2 +- scripts/dav_mp20_stats.py | 1 + scripts/fit_lattice_proxy.py | 1 - scripts/mp20_matbench_lp_range.py | 1 - 11 files changed, 101 insertions(+), 52 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 44c0e030b..59fce6542 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -128,7 +128,7 @@ def __init__( print(f"\tStd score: {self.buffer.test['energies'].std()}") print(f"\tMin score: {self.buffer.test['energies'].min()}") print(f"\tMax score: {self.buffer.test['energies'].max()}") - + # Models self.forward_policy = forward_policy if self.forward_policy.checkpoint is not None: @@ -157,8 +157,8 @@ def __init__( self.state_flow = state_flow if self.state_flow is not None and self.state_flow.checkpoint is not None: - self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint) - # TODO: add the logic and conditions to reload a model + self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint) + # TODO: add the logic and conditions to reload a model else: self.logger.set_state_flow_ckpt_path(None) @@ -202,7 +202,7 @@ def parameters(self): parameters += list(self.backward_policy.model.parameters()) if self.state_flow is not None: if self.loss != "forwardlooking": - raise ValueError(f"State flow cannot be trained in {self.loss} loss.") + raise ValueError(f"State flow cannot be trained in {self.loss} loss.") parameters += list(self.state_flow.model.parameters()) return parameters @@ -424,12 +424,22 @@ def sample_batch( "actions_envs": 0.0, } t0_all = time.time() - batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) # ON-POLICY FORWARD trajectories t0_forward = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_forward)] - batch_forward = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch_forward = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) while envs: # Sample actions t0_a_envs = time.time() @@ -451,7 +461,12 @@ def sample_batch( # TRAIN BACKWARD trajectories t0_train = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_train)] - batch_train = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch_train = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: dict_tr = pickle.load(f) @@ -482,7 +497,12 @@ def sample_batch( # REPLAY BACKWARD trajectories t0_replay = time.time() - batch_replay = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch_replay = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) if n_replay > 0 and self.buffer.replay_pkl is not None: with open(self.buffer.replay_pkl, "rb") as f: dict_replay = pickle.load(f) @@ -720,18 +740,17 @@ def forwardlooking_loss(self, it, batch): masks_b = batch.get_masks_backward() policy_output_b = self.backward_policy(states_policy) logprobs_bkw = self.env.get_logprobs( - policy_output_b, actions, masks_b, states, is_backward=True - ) + policy_output_b, actions, masks_b, states, is_backward=True + ) masks_f = batch.get_masks_forward(of_parents=True) policy_output_f = self.forward_policy(parents_policy) logprobs_fwd = self.env.get_logprobs( - policy_output_f, actions, masks_f, parents, is_backward=False - ) - + policy_output_f, actions, masks_f, parents, is_backward=False + ) states_log_flflow = self.state_flow(states_policy) # forward-looking flow is 1 in the terminal states - states_log_flflow[done.eq(1)] = 0. + states_log_flflow[done.eq(1)] = 0.0 # Can be optimised by reusing states_log_flflow and batch.get_parent_indices parents_log_flflow = self.state_flow(parents_policy) @@ -741,9 +760,15 @@ def forwardlooking_loss(self, it, batch): energies_states = -torch.log(rewards_states) energies_parents = -torch.log(rewards_parents) - per_node_loss = (parents_log_flflow - states_log_flflow + logprobs_fwd - logprobs_bkw + - energies_states - energies_parents).pow(2) - + per_node_loss = ( + parents_log_flflow + - states_log_flflow + + logprobs_fwd + - logprobs_bkw + + energies_states + - energies_parents + ).pow(2) + term_loss = per_node_loss[done].mean() nonterm_loss = per_node_loss[~done].mean() loss = per_node_loss.mean() @@ -840,7 +865,12 @@ def estimate_logprobs_data( end_batch = min(batch_size, n_states) pbar = tqdm(total=n_states) while init_batch < n_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) # Create an environment for each data point and trajectory and set the state envs = [] for state_idx in range(init_batch, end_batch): @@ -939,7 +969,12 @@ def train(self): self.logger.log_metrics(metrics, use_context=self.use_context, step=it) self.logger.log_summary(summary) t0_iter = time.time() - batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) for j in range(self.sttr): sub_batch, times = self.sample_batch( n_forward=self.batch_size.forward, @@ -1021,7 +1056,9 @@ def train(self): times.update({"log": t1_log - t0_log}) # Save intermediate models t0_model = time.time() - self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, step=it) + self.logger.save_models( + self.forward_policy, self.backward_policy, self.state_flow, step=it + ) t1_model = time.time() times.update({"save_interim_model": t1_model - t0_model}) @@ -1050,7 +1087,9 @@ def train(self): self.logger.log_time(times, use_context=self.use_context) # Save final model - self.logger.save_models(self.forward_policy, self.backward_policy, self.state_flow, final=True) + self.logger.save_models( + self.forward_policy, self.backward_policy, self.state_flow, final=True + ) # Close logger if self.use_context is False: self.logger.end() @@ -1225,7 +1264,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): print() if not gfn_states: # sample states from the current gfn - batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) self.random_action_prob = 0 t = time.time() print("Sampling from GFN...", end="\r") @@ -1248,7 +1292,12 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): if do_random: # sample random states from uniform actions if not random_states: - batch = Batch(env=self.env, device=self.device, float_type=self.float, non_terminal_rewards=self.non_terminal_rewards) + batch = Batch( + env=self.env, + device=self.device, + float_type=self.float, + non_terminal_rewards=self.non_terminal_rewards, + ) self.random_action_prob = 1.0 print("[test_top_k] Sampling at random...", end="\r") for b in batch_with_rest( diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 5122891a4..bf1ef2198 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,7 +1,9 @@ +from abc import ABC, abstractmethod + import torch from omegaconf import OmegaConf from torch import nn -from abc import ABC, abstractmethod + from gflownet.utils.common import set_device, set_float_precision @@ -14,7 +16,7 @@ def __init__(self, config, input_dim, device, float_precision, base=None): self.input_dim = input_dim # Must be redefined in the children classes self.output_dim = None - + # Optional base model self.base = base @@ -90,13 +92,12 @@ def make_mlp(self, activation): ) - class Policy(ModelBase): def __init__(self, config, env, device, float_precision, base=None): super().__init__(config, env.policy_input_dim, device, float_precision, base) - # Outputs - + # Outputs + self.fixed_output = torch.tensor(env.fixed_policy_output).to( dtype=self.float, device=self.device ) @@ -107,7 +108,6 @@ def __init__(self, config, env, device, float_precision, base=None): self.instantiate() - def instantiate(self): if self.type == "fixed": self.model = self.fixed_distribution diff --git a/gflownet/policy/state_flow.py b/gflownet/policy/state_flow.py index 6417dfac2..8542c8a15 100644 --- a/gflownet/policy/state_flow.py +++ b/gflownet/policy/state_flow.py @@ -1,19 +1,21 @@ import torch from torch import nn -from gflownet.utils.common import set_device, set_float_precision from gflownet.policy.base import ModelBase +from gflownet.utils.common import set_device, set_float_precision + class StateFlow(ModelBase): """ Takes state in the policy format and predicts its flow (a scalar) """ + def __init__(self, config, env, device, float_precision, base=None): super().__init__(config, env.policy_input_dim, device, float_precision, base) - + # output dim self.output_dim = 1 - + # Instantiate neural network self.instantiate() @@ -22,4 +24,4 @@ def instantiate(self): self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) self.is_model = True else: - raise "StateFlow model type not defined" \ No newline at end of file + raise "StateFlow model type not defined" diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index c1d72a703..b6da02691 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -39,7 +39,7 @@ def __init__( env: Optional[GFlowNetEnv] = None, device: Union[str, torch.device] = "cpu", float_type: Union[int, torch.dtype] = 32, - non_terminal_rewards: bool = False + non_terminal_rewards: bool = False, ): """ env : GFlowNetEnv @@ -567,8 +567,10 @@ def _compute_parents(self): # Sort parents list in the same order as states # TODO: check if tensor and sort without iter self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))] - self.parents_indices = tlong([self.parents_indices[indices.index(idx)] for idx in range(len(self))], - device=self.device) + self.parents_indices = tlong( + [self.parents_indices[indices.index(idx)] for idx in range(len(self))], + device=self.device, + ) self.parents_available = True # TODO: consider converting directly from self.parents @@ -853,19 +855,17 @@ def get_rewards( if self.rewards_available is False or force_recompute is True: self._compute_rewards() return self.rewards - + def _compute_rewards(self): """ Computes rewards for all self.states by first converting the states into proxy format. The result is stored in self.rewards as a torch.tensor """ - + self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) done = self.get_done() if self.non_terminal_rewards: - self.rewards = self.env.proxy2reward( - self.env.proxy(self.states2proxy()) - ) + self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy())) elif len(done) > 0: states_proxy_done = self.get_terminating_states(proxy=True) self.rewards[done] = self.env.proxy2reward( @@ -884,13 +884,15 @@ def get_rewards_parents(self) -> TensorType["n_states"]: def _compute_rewards_parents(self): """ Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). - Stores the result in self.rewards_parents + Stores the result in self.rewards_parents """ state_rewards = self.get_rewards() self.rewards_parents = torch.zeros_like(state_rewards) parent_is_source = self.get_parent_is_source() parent_indices = self.get_parents_indices() - self.rewards_parents[~parent_is_source] = self.rewards[parent_indices[~parent_is_source]] + self.rewards_parents[~parent_is_source] = self.rewards[ + parent_indices[~parent_is_source] + ] rewards_source = self.get_rewards_source() self.rewards_parents[parent_is_source] = rewards_source[parent_is_source] self.rewards_parents_available = True @@ -902,7 +904,7 @@ def get_rewards_source(self) -> TensorType["n_states"]: if not self.rewards_source_available: self._compute_rewards_source() return self.rewards_source - + def _compute_rewards_source(self): """ Computes a tensor of length len(self.states) with rewards of the corresponding source states. diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 0012317e8..03cd4f279 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -387,9 +387,6 @@ def save_models( path = self.sf_ckpt_path.parent / stem torch.save(state_flow.model.state_dict(), path) - - - def log_time(self, times: dict, use_context: bool): if self.do.times: times = {"time_{}".format(k): v for k, v in times.items()} diff --git a/gflownet/utils/policy.py b/gflownet/utils/policy.py index 971e8a2a9..973e16c84 100644 --- a/gflownet/utils/policy.py +++ b/gflownet/utils/policy.py @@ -30,4 +30,4 @@ def parse_policy_config(config: DictConfig, kind: str) -> Optional[DictConfig]: del policy_config.backward del policy_config.shared - return policy_config \ No newline at end of file + return policy_config diff --git a/main.py b/main.py index 79ae422f8..6e10cf771 100644 --- a/main.py +++ b/main.py @@ -68,7 +68,7 @@ def main(config): device=config.device, float_precision=config.float_precision, base=forward_policy, - ) + ) else: state_flow = None diff --git a/mila/launch.py b/mila/launch.py index 5e2f5379f..4ca36aef5 100644 --- a/mila/launch.py +++ b/mila/launch.py @@ -7,8 +7,8 @@ from os.path import expandvars from pathlib import Path from textwrap import dedent -from git import Repo +from git import Repo from yaml import safe_load ROOT = Path(__file__).resolve().parent.parent diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 2b3e7ee5d..3df1c78c9 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -20,6 +20,7 @@ from collections import Counter from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders + from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/scripts/fit_lattice_proxy.py b/scripts/fit_lattice_proxy.py index a416a9bc8..83550d8b2 100644 --- a/scripts/fit_lattice_proxy.py +++ b/scripts/fit_lattice_proxy.py @@ -17,7 +17,6 @@ from gflownet.envs.crystals.lattice_parameters import LatticeParameters from gflownet.proxy.crystals.lattice_parameters import PICKLE_PATH - DATASET_PATH = ( Path(__file__).parents[1] / "data" / "crystals" / "matbench_mp_e_form_lp_stats.csv" ) diff --git a/scripts/mp20_matbench_lp_range.py b/scripts/mp20_matbench_lp_range.py index 4d3ec5180..8ae7fdedd 100644 --- a/scripts/mp20_matbench_lp_range.py +++ b/scripts/mp20_matbench_lp_range.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd - if __name__ == "__main__": mp = pd.read_csv(Path(__file__).parents[1] / "data/crystals/mp20_lp_stats.csv") mb = pd.read_csv( From 9aeec3ed259fb89995db29997895f490bc5b7736 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 11:51:13 -0500 Subject: [PATCH 07/34] Remove get_parent_is_source() because it is used only once and it is simple enough to do in place --- gflownet/utils/batch.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index b6da02691..64d085126 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -531,11 +531,6 @@ def get_parents_indices(self): self._compute_parents() return self.parents_indices - def get_parent_is_source(self): - if self.parents_available is False: - self._compute_parents() - return self.parents_indices == -1 - def _compute_parents(self): """ Obtains the parent (single parent for each state) of all states in the batch. @@ -888,8 +883,8 @@ def _compute_rewards_parents(self): """ state_rewards = self.get_rewards() self.rewards_parents = torch.zeros_like(state_rewards) - parent_is_source = self.get_parent_is_source() parent_indices = self.get_parents_indices() + parent_is_source = parent_indices == -1 self.rewards_parents[~parent_is_source] = self.rewards[ parent_indices[~parent_is_source] ] From 04bd98e4882a9e72826898237440b4b5a5481833 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 13:10:40 -0500 Subject: [PATCH 08/34] Make forwardlooking config in gflownet/ as with other losses and move state_flow inside gflownet --- config/gflownet/forwardlooking.yaml | 9 +++++++++ config/gflownet/gflownet.yaml | 2 ++ config/{ => gflownet}/state_flow/mlp.yaml | 0 config/main.yaml | 1 - config/policy/mlp_forwardlooking.yaml | 7 +++++++ main.py | 10 ++++++---- 6 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 config/gflownet/forwardlooking.yaml rename config/{ => gflownet}/state_flow/mlp.yaml (100%) create mode 100644 config/policy/mlp_forwardlooking.yaml diff --git a/config/gflownet/forwardlooking.yaml b/config/gflownet/forwardlooking.yaml new file mode 100644 index 000000000..c2c641719 --- /dev/null +++ b/config/gflownet/forwardlooking.yaml @@ -0,0 +1,9 @@ +defaults: + - gflownet + - state_flow: mlp + +optimizer: + loss: forwardlooking + lr: 0.0001 + lr_decay_period: 1000000 + lr_decay_gamma: 0.5 diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml index 22dd6dd10..c33e52eaf 100644 --- a/config/gflownet/gflownet.yaml +++ b/config/gflownet/gflownet.yaml @@ -34,6 +34,8 @@ optimizer: # From original implementation bootstrap_tau: 0.0 clip_grad_norm: 0.0 +# State flow modelling +state_flow: null # If True, compute rewards in batches batch_reward: True # Force zero probability of sampling invalid actions diff --git a/config/state_flow/mlp.yaml b/config/gflownet/state_flow/mlp.yaml similarity index 100% rename from config/state_flow/mlp.yaml rename to config/gflownet/state_flow/mlp.yaml diff --git a/config/main.yaml b/config/main.yaml index c550bfaf6..7ea98e735 100644 --- a/config/main.yaml +++ b/config/main.yaml @@ -3,7 +3,6 @@ defaults: - env: grid - gflownet: flowmatch - policy: mlp_${gflownet} - - state_flow: null - proxy: corners - logger: wandb - user: alex diff --git a/config/policy/mlp_forwardlooking.yaml b/config/policy/mlp_forwardlooking.yaml new file mode 100644 index 000000000..41f43231e --- /dev/null +++ b/config/policy/mlp_forwardlooking.yaml @@ -0,0 +1,7 @@ +defaults: + - mlp + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/main.py b/main.py index 6e10cf771..d7aab0c5f 100644 --- a/main.py +++ b/main.py @@ -60,10 +60,10 @@ def main(config): float_precision=config.float_precision, base=forward_policy, ) - - if config.gflownet.optimizer.loss in ["forwardlooking", "fl"]: + # State flow + if config.gflownet.state_flow is not None: state_flow = hydra.utils.instantiate( - config.state_flow, + config.gflownet.state_flow, env=env, device=config.device, float_precision=config.float_precision, @@ -71,7 +71,7 @@ def main(config): ) else: state_flow = None - + # GFlowNet Agent gflownet = hydra.utils.instantiate( config.gflownet, device=config.device, @@ -83,6 +83,8 @@ def main(config): buffer=config.env.buffer, logger=logger, ) + + # Train GFlowNet gflownet.train() # Sample from trained GFlowNet From c7331727270fc704609bf0f0b1f8e46bb04b1d90 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 14:02:40 -0500 Subject: [PATCH 09/34] Remove attribute self.non_terminal_rewards from both GFlowNetAgent and Batch --- gflownet/gflownet.py | 14 +------------- gflownet/utils/batch.py | 38 ++++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 27 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 59fce6542..b66f3e91e 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -77,15 +77,12 @@ def __init__( if optimizer.loss in ["flowmatch", "flowmatching"]: self.loss = "flowmatch" self.logZ = None - self.non_terminal_rewards = False elif optimizer.loss in ["trajectorybalance", "tb"]: self.loss = "trajectorybalance" self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64) - self.non_terminal_rewards = False elif optimizer.loss in ["forwardlooking", "fl"]: self.loss = "forwardlooking" self.logZ = None - self.non_terminal_rewards = True else: print("Unkown loss. Using flowmatch as default") self.loss = "flowmatch" @@ -428,7 +425,6 @@ def sample_batch( env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) # ON-POLICY FORWARD trajectories @@ -438,7 +434,6 @@ def sample_batch( env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) while envs: # Sample actions @@ -465,7 +460,6 @@ def sample_batch( env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: @@ -501,7 +495,6 @@ def sample_batch( env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) if n_replay > 0 and self.buffer.replay_pkl is not None: with open(self.buffer.replay_pkl, "rb") as f: @@ -754,8 +747,7 @@ def forwardlooking_loss(self, it, batch): # Can be optimised by reusing states_log_flflow and batch.get_parent_indices parents_log_flflow = self.state_flow(parents_policy) - assert batch.non_terminal_rewards - rewards_states = batch.get_rewards() + rewards_states = batch.get_rewards(do_non_terminating=True) rewards_parents = batch.get_rewards_parents() energies_states = -torch.log(rewards_states) energies_parents = -torch.log(rewards_parents) @@ -869,7 +861,6 @@ def estimate_logprobs_data( env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) # Create an environment for each data point and trajectory and set the state envs = [] @@ -973,7 +964,6 @@ def train(self): env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) for j in range(self.sttr): sub_batch, times = self.sample_batch( @@ -1268,7 +1258,6 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) self.random_action_prob = 0 t = time.time() @@ -1296,7 +1285,6 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): env=self.env, device=self.device, float_type=self.float, - non_terminal_rewards=self.non_terminal_rewards, ) self.random_action_prob = 1.0 print("[test_top_k] Sampling at random...", end="\r") diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 64d085126..2ae8e0d30 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -39,7 +39,6 @@ def __init__( env: Optional[GFlowNetEnv] = None, device: Union[str, torch.device] = "cpu", float_type: Union[int, torch.dtype] = 32, - non_terminal_rewards: bool = False, ): """ env : GFlowNetEnv @@ -57,8 +56,6 @@ def __init__( self.device = set_device(device) # Float precision self.float = set_float_precision(float_type) - # Whether rewards should be computed for non-terminal states - self.non_terminal_rewards = non_terminal_rewards # Generic environment, properties and dictionary of state and forward mask of # source (as tensor) if env is not None: @@ -837,7 +834,9 @@ def _compute_masks_backward(self): self.masks_backward_available = True def get_rewards( - self, force_recompute: Optional[bool] = False + self, + force_recompute: Optional[bool] = False, + do_non_terminating: Optional[bool] = False, ) -> TensorType["n_states"]: """ Returns the rewards of all states in the batch (including not done). @@ -846,26 +845,37 @@ def get_rewards( ---- force_recompute : bool If True, the rewards are recomputed even if they are available. + + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. """ if self.rewards_available is False or force_recompute is True: - self._compute_rewards() + self._compute_rewards(do_non_terminating) return self.rewards - def _compute_rewards(self): + def _compute_rewards(self, do_non_terminating: Optional[bool] = False): """ Computes rewards for all self.states by first converting the states into proxy format. The result is stored in self.rewards as a torch.tensor + + Args + ---- + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. """ - self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) - done = self.get_done() - if self.non_terminal_rewards: + if do_non_terminating: self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy())) - elif len(done) > 0: - states_proxy_done = self.get_terminating_states(proxy=True) - self.rewards[done] = self.env.proxy2reward( - self.env.proxy(states_proxy_done) - ) + else: + self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) + done = self.get_done() + if len(done) > 0: + states_proxy_done = self.get_terminating_states(proxy=True) + self.rewards[done] = self.env.proxy2reward( + self.env.proxy(states_proxy_done) + ) self.rewards_available = True def get_rewards_parents(self) -> TensorType["n_states"]: From d1dcbbb7bfc1c9bf97f29014c067eeb3664e5a6e Mon Sep 17 00:00:00 2001 From: Alexandra Date: Mon, 27 Nov 2023 14:27:12 -0500 Subject: [PATCH 10/34] Update gflownet/gflownet.py Co-authored-by: Alex --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index b66f3e91e..9cae957a9 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -202,7 +202,7 @@ def parameters(self): raise ValueError(f"State flow cannot be trained in {self.loss} loss.") parameters += list(self.state_flow.model.parameters()) return parameters - + raise ValueError(f"State flow cannot be trained with {self.loss} loss.") def sample_actions( self, envs: List[GFlowNetEnv], From 913bcfb8b3dde3b4ce824036b036dfc8d2191059 Mon Sep 17 00:00:00 2001 From: Alexandra Date: Mon, 27 Nov 2023 14:27:45 -0500 Subject: [PATCH 11/34] Update gflownet/gflownet.py Co-authored-by: Alex --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 9cae957a9..fa2272fdb 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -195,7 +195,7 @@ def parameters(self): parameters = list(self.forward_policy.model.parameters()) if self.backward_policy.is_model: if self.loss == "flowmatch": - raise ValueError("Backward Policy cannot be a nn in flowmatch.") + raise ValueError("Backward Policy cannot be a model in flowmatch.") parameters += list(self.backward_policy.model.parameters()) if self.state_flow is not None: if self.loss != "forwardlooking": From 9f4dea2add8f06f256caaf99ffedcd6e1ee7a64c Mon Sep 17 00:00:00 2001 From: Alexandra Date: Mon, 27 Nov 2023 14:29:38 -0500 Subject: [PATCH 12/34] Update gflownet/utils/batch.py, iteration over indices Co-authored-by: Alex --- gflownet/utils/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 2ae8e0d30..5c935a082 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -554,7 +554,7 @@ def _compute_parents(self): # parent is not source # TODO: check if tensor and sort without iter self.parents.extend([self.states[idx] for idx in batch_indices[:-1]]) - self.parents_indices.extend([idx for idx in batch_indices[:-1]]) + self.parents_indices.extend(batch_indices[:-1]) indices.extend(batch_indices) # Sort parents list in the same order as states # TODO: check if tensor and sort without iter From 1d7a838f1e430e707c93e83ed051e9b3a0210bc8 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 27 Nov 2023 14:34:42 -0500 Subject: [PATCH 13/34] bug fix --- gflownet/gflownet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index fa2272fdb..3452c5953 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -199,10 +199,10 @@ def parameters(self): parameters += list(self.backward_policy.model.parameters()) if self.state_flow is not None: if self.loss != "forwardlooking": - raise ValueError(f"State flow cannot be trained in {self.loss} loss.") + raise ValueError(f"State flow cannot be trained with {self.loss} loss.") parameters += list(self.state_flow.model.parameters()) return parameters - raise ValueError(f"State flow cannot be trained with {self.loss} loss.") + def sample_actions( self, envs: List[GFlowNetEnv], From 35f1361f4845b8d40ed032e345d52f427c46e278 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 14:36:52 -0500 Subject: [PATCH 14/34] unblack --- gflownet/gflownet.py | 48 ++++++++------------------------------------ 1 file changed, 8 insertions(+), 40 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 3452c5953..7372a215b 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -421,20 +421,12 @@ def sample_batch( "actions_envs": 0.0, } t0_all = time.time() - batch = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch = Batch(env=self.env, device=self.device, float_type=self.float) # ON-POLICY FORWARD trajectories t0_forward = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_forward)] - batch_forward = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch_forward = Batch(env=self.env, device=self.device, float_type=self.float) while envs: # Sample actions t0_a_envs = time.time() @@ -456,11 +448,7 @@ def sample_batch( # TRAIN BACKWARD trajectories t0_train = time.time() envs = [self.env.copy().reset(idx) for idx in range(n_train)] - batch_train = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch_train = Batch(env=self.env, device=self.device, float_type=self.float) if n_train > 0 and self.buffer.train_pkl is not None: with open(self.buffer.train_pkl, "rb") as f: dict_tr = pickle.load(f) @@ -491,11 +479,7 @@ def sample_batch( # REPLAY BACKWARD trajectories t0_replay = time.time() - batch_replay = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch_replay = Batch(env=self.env, device=self.device, float_type=self.float) if n_replay > 0 and self.buffer.replay_pkl is not None: with open(self.buffer.replay_pkl, "rb") as f: dict_replay = pickle.load(f) @@ -857,11 +841,7 @@ def estimate_logprobs_data( end_batch = min(batch_size, n_states) pbar = tqdm(total=n_states) while init_batch < n_states: - batch = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch = Batch(env=self.env, device=self.device, float_type=self.float) # Create an environment for each data point and trajectory and set the state envs = [] for state_idx in range(init_batch, end_batch): @@ -960,11 +940,7 @@ def train(self): self.logger.log_metrics(metrics, use_context=self.use_context, step=it) self.logger.log_summary(summary) t0_iter = time.time() - batch = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch = Batch(env=self.env, device=self.device, float_type=self.float) for j in range(self.sttr): sub_batch, times = self.sample_batch( n_forward=self.batch_size.forward, @@ -1254,11 +1230,7 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): print() if not gfn_states: # sample states from the current gfn - batch = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch = Batch(env=self.env, device=self.device, float_type=self.float) self.random_action_prob = 0 t = time.time() print("Sampling from GFN...", end="\r") @@ -1281,11 +1253,7 @@ def test_top_k(self, it, progress=False, gfn_states=None, random_states=None): if do_random: # sample random states from uniform actions if not random_states: - batch = Batch( - env=self.env, - device=self.device, - float_type=self.float, - ) + batch = Batch(env=self.env, device=self.device, float_type=self.float) self.random_action_prob = 1.0 print("[test_top_k] Sampling at random...", end="\r") for b in batch_with_rest( From 1907e708f6bc8fbd0479a2460dea7f21805cde5f Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 16:07:52 -0500 Subject: [PATCH 15/34] Add do_non_terminating=True to get_rewards in getting parents rewards --- gflownet/utils/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 5c935a082..5ef4cb1eb 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -891,7 +891,7 @@ def _compute_rewards_parents(self): Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). Stores the result in self.rewards_parents """ - state_rewards = self.get_rewards() + state_rewards = self.get_rewards(do_non_terminating=True) self.rewards_parents = torch.zeros_like(state_rewards) parent_indices = self.get_parents_indices() parent_is_source = parent_indices == -1 From 5a135d9cae55eeedcffcef0a16ec1f651d7fc725 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 16:44:20 -0500 Subject: [PATCH 16/34] Implementation of the detailed balance loss and the necessary configuration files --- config/gflownet/detailedbalance.yaml | 9 ++++ config/policy/mlp_detailedbalance.yaml | 7 +++ gflownet/gflownet.py | 66 +++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 1 deletion(-) create mode 100644 config/gflownet/detailedbalance.yaml create mode 100644 config/policy/mlp_detailedbalance.yaml diff --git a/config/gflownet/detailedbalance.yaml b/config/gflownet/detailedbalance.yaml new file mode 100644 index 000000000..073ae99f4 --- /dev/null +++ b/config/gflownet/detailedbalance.yaml @@ -0,0 +1,9 @@ +defaults: + - gflownet + - state_flow: mlp + +optimizer: + loss: detailedbalance + lr: 0.0001 + lr_decay_period: 1000000 + lr_decay_gamma: 0.5 diff --git a/config/policy/mlp_detailedbalance.yaml b/config/policy/mlp_detailedbalance.yaml new file mode 100644 index 000000000..41f43231e --- /dev/null +++ b/config/policy/mlp_detailedbalance.yaml @@ -0,0 +1,7 @@ +defaults: + - mlp + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 7372a215b..09934949d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -80,6 +80,9 @@ def __init__( elif optimizer.loss in ["trajectorybalance", "tb"]: self.loss = "trajectorybalance" self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64) + elif optimizer.loss in ["detailedbalance", "db"]: + self.loss = "detailedbalance" + self.logZ = None elif optimizer.loss in ["forwardlooking", "fl"]: self.loss = "forwardlooking" self.logZ = None @@ -198,7 +201,7 @@ def parameters(self): raise ValueError("Backward Policy cannot be a model in flowmatch.") parameters += list(self.backward_policy.model.parameters()) if self.state_flow is not None: - if self.loss != "forwardlooking": + if self.loss not in ["detailedbalance", "forwardlooking"]: raise ValueError(f"State flow cannot be trained with {self.loss} loss.") parameters += list(self.state_flow.model.parameters()) return parameters @@ -679,6 +682,65 @@ def trajectorybalance_loss(self, it, batch): ) return loss, loss, loss + def detailedbalance_loss(self, it, batch): + """ + Computes the Detailed Balance GFlowNet loss of a batch + Reference : https://arxiv.org/pdf/2201.13259.pdf (eq 11) + + Args + ---- + it : int + Iteration + + batch : Batch + A batch of data, containing all the states in the trajectories. + + + Returns + ------- + loss : float + + term_loss : float + Loss of the terminal nodes only + + nonterm_loss : float + Loss of the intermediate nodes only + """ + + assert batch.is_valid() + # Get necessary tensors from batch + states = batch.get_states(policy=False) + states_policy = batch.get_states(policy=True) + actions = batch.get_actions() + parents = batch.get_parents(policy=False) + parents_policy = batch.get_parents(policy=True) + done = batch.get_done() + rewards = batch.get_terminating_rewards(sort_by="insertion") + + # Get logprobs + masks_f = batch.get_masks_forward(of_parents=True) + policy_output_f = self.forward_policy(parents_policy) + logprobs_f = self.env.get_logprobs( + policy_output_f, actions, masks_f, parents, is_backward=False + ) + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_b = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) + + # Get logflows + logflow_states = self.state_flow(states_policy).squeeze() + logflow_states[done.eq(1)] = rewards + # TODO: Optimise by reusing logflow_states and batch.get_parent_indices + logflow_parents = self.state_flow(parents_policy).squeeze() + + # Detailed balance loss + loss = ( + (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2).mean() + ) + return loss, loss, loss + def forwardlooking_loss(self, it, batch): """ Computes the Forward-Looking GFlowNet loss of a batch @@ -957,6 +1019,8 @@ def train(self): losses = self.trajectorybalance_loss( it * self.ttsr + j, batch ) # returns (opt loss, *metrics) + elif self.loss == "detailedbalance": + losses = self.detailedbalance_loss(it * self.ttsr + j, batch) elif self.loss == "forwardlooking": losses = self.forwardlooking_loss(it * self.ttsr + j, batch) else: From f27a8262ea8f17dcf1ebf582de86330be0d26842 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 17:37:41 -0500 Subject: [PATCH 17/34] Fix: logflows of terminating is log(rewards) --- gflownet/gflownet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 09934949d..abbd82053 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -731,7 +731,7 @@ def detailedbalance_loss(self, it, batch): # Get logflows logflow_states = self.state_flow(states_policy).squeeze() - logflow_states[done.eq(1)] = rewards + logflow_states[done.eq(1)] = torch.log(rewards) # TODO: Optimise by reusing logflow_states and batch.get_parent_indices logflow_parents = self.state_flow(parents_policy).squeeze() From 9771adba032b872e1d70ea123cdb374657110ba2 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 27 Nov 2023 17:53:35 -0500 Subject: [PATCH 18/34] tests, docstrings, doo_non_terminating in get_parents_rewards --- gflownet/envs/base.py | 4 +- gflownet/gflownet.py | 2 +- gflownet/utils/batch.py | 29 +++++- tests/gflownet/utils/test_batch.py | 153 +++++++++++++++++++++++++++++ 4 files changed, 180 insertions(+), 8 deletions(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index 9b5e81f3a..e1cae5b86 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -780,13 +780,13 @@ def traj2readable(self, traj=None): """ return str(traj).replace("(", "[").replace(")", "]").replace(",", "") - def reward(self, state=None, done=None): + def reward(self, state=None, done=None, do_non_terminating=False): """ Computes the reward of a state """ state = self._get_state(state) done = self._get_done(done) - if done is False: + if done is False and do_non_terminating is False: return tfloat(0.0, float_type=self.float, device=self.device) return self.proxy2reward(self.proxy(self.state2proxy(state))[0]) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 7372a215b..1b48000c5 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -732,7 +732,7 @@ def forwardlooking_loss(self, it, batch): parents_log_flflow = self.state_flow(parents_policy) rewards_states = batch.get_rewards(do_non_terminating=True) - rewards_parents = batch.get_rewards_parents() + rewards_parents = batch.get_rewards_parents(do_non_terminating=True) energies_states = -torch.log(rewards_states) energies_parents = -torch.log(rewards_parents) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 5c935a082..9f7ed5d10 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -524,13 +524,18 @@ def get_parents( return self.parents def get_parents_indices(self): + """ + Returns indices of the parents of the states in the batch. + Each index corresponds to the position of the patent in the self.states tensor, if it is peresent there. + If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 + """ if self.parents_available is False: self._compute_parents() return self.parents_indices def _compute_parents(self): """ - Obtains the parent (single parent for each state) of all states in the batch. + Obtains the parent (single parent for each state) of all states in the batch and its index. The parents are computed, obtaining all necessary components, if they are not readily available. Missing components and newly computed components are added to the batch (self.component is set). The following variable is stored: @@ -539,6 +544,8 @@ def _compute_parents(self): as self.states (list of lists or tensor) Length: n_states Shape: [n_states, state_dims] + - self.parents_indices: the position of each parent in self.states tensor. + If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 self.parents_available is set to True. """ @@ -878,20 +885,32 @@ def _compute_rewards(self, do_non_terminating: Optional[bool] = False): ) self.rewards_available = True - def get_rewards_parents(self) -> TensorType["n_states"]: + def get_rewards_parents(self, do_non_terminating=False) -> TensorType["n_states"]: """ Returns the rewards of all parents in the batch + + Args + ---- + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. """ if not self.rewards_parents_available: - self._compute_rewards_parents() + self._compute_rewards_parents(do_non_terminating=do_non_terminating) return self.rewards_parents - def _compute_rewards_parents(self): + def _compute_rewards_parents(self, do_non_terminating=False): """ Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). Stores the result in self.rewards_parents + + Args + ---- + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. """ - state_rewards = self.get_rewards() + state_rewards = self.get_rewards(do_non_terminating=do_non_terminating) self.rewards_parents = torch.zeros_like(state_rewards) parent_indices = self.get_parents_indices() parent_is_source = parent_indices == -1 diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 338dfd061..910066a62 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -71,6 +71,10 @@ def corners(): def tetris_score(): return TetrisScore(device="cpu", float_precision=32, normalize=False) +@pytest.fixture() +def tetris_score_norm(): + return TetrisScore(device="cpu", float_precision=32, normalize=True) + # @pytest.mark.skip(reason="skip while developping other tests") def test__len__returnszero_at_init(batch): @@ -1325,3 +1329,152 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive( assert torch.equal( traj_indices_batch, tlong(traj_indices_consecutive, device=batch.device) ) + +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], +) +# @pytest.mark.skip(reason="skip while developping other tests") +def test__get_rewards__single_env_returns_expected_non_terminal(env, proxy, batch, request): + env = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env = env.reset() + env.proxy = proxy + env.setup_proxy() + batch.set_env(env) + + rewards = [] + while not env.done: + parent = env.state + # Sample random action + _, action, valid = env.step_random() + # Add to batch + batch.add_to_batch([env], [action], [valid]) + if valid: + rewards.append(env.reward(do_non_terminating=True)) + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch) + + +@pytest.mark.repeat(N_REPETITIONS) +# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], +) +def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(env, proxy, batch, request): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env_ref = env_ref.reset() + env_ref.proxy = proxy + env_ref.setup_proxy() + env_ref.reward_func = 'boltzmann' + + batch.set_env(env_ref) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + rewards = [] + proxy_values = [] + + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + rewards.append(env.reward(do_non_terminating=True)) + proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch) + assert ~torch.any(torch.isclose(rewards_batch, torch.zeros_like(rewards_batch))), rewards_batch + +@pytest.mark.repeat(N_REPETITIONS) +# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm"), ("ctorus2d5l", "corners")], +) +def test__get_rewards_parents_multiple_env_returns_expected_non_terminal(env, proxy, batch, request): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env_ref = env_ref.reset() + env_ref.proxy = proxy + env_ref.setup_proxy() + + batch.set_env(env_ref) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + rewards_parents = [] + rewards = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + done_parent = env.done + + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + rewards_parents.append(env.reward(state=parent, done=done_parent, do_non_terminating=True)) + rewards.append(env.reward(do_non_terminating=True)) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + rewards_parents_batch = batch.get_rewards_parents(do_non_terminating=True) + rewards_parents = torch.stack(rewards_parents) + + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + + assert torch.equal( + rewards_parents_batch, + tfloat(rewards_parents, device=batch.device, float_type=batch.float), + ), (rewards_parents, rewards_parents_batch) + + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch) \ No newline at end of file From 391c304857628cfff849aba040e08529f8039ab9 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 17:54:29 -0500 Subject: [PATCH 19/34] Fix bug in FL loss (logflows needed to be squeezed); now it seems to work like a charm --- gflownet/gflownet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 7372a215b..e686a0b0d 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -725,11 +725,11 @@ def forwardlooking_loss(self, it, batch): policy_output_f, actions, masks_f, parents, is_backward=False ) - states_log_flflow = self.state_flow(states_policy) + states_log_flflow = self.state_flow(states_policy).squeeze() # forward-looking flow is 1 in the terminal states states_log_flflow[done.eq(1)] = 0.0 # Can be optimised by reusing states_log_flflow and batch.get_parent_indices - parents_log_flflow = self.state_flow(parents_policy) + parents_log_flflow = self.state_flow(parents_policy).squeeze() rewards_states = batch.get_rewards(do_non_terminating=True) rewards_parents = batch.get_rewards_parents() From c0f4ef45359f8a55d074e4c080a884ce53285f8a Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 27 Nov 2023 18:05:10 -0500 Subject: [PATCH 20/34] isort, black --- gflownet/utils/batch.py | 8 +++---- tests/gflownet/utils/test_batch.py | 34 ++++++++++++++++++++++-------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index e44b58da9..f71dff505 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -525,7 +525,7 @@ def get_parents( def get_parents_indices(self): """ - Returns indices of the parents of the states in the batch. + Returns indices of the parents of the states in the batch. Each index corresponds to the position of the patent in the self.states tensor, if it is peresent there. If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 """ @@ -544,8 +544,8 @@ def _compute_parents(self): as self.states (list of lists or tensor) Length: n_states Shape: [n_states, state_dims] - - self.parents_indices: the position of each parent in self.states tensor. - If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 + - self.parents_indices: the position of each parent in self.states tensor. + If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 self.parents_available is set to True. """ @@ -898,7 +898,7 @@ def _compute_rewards_parents(self): Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). Stores the result in self.rewards_parents """ - # TODO: this may return zero rewards for all parents if before + # TODO: this may return zero rewards for all parents if before # rewards for states were computed with do_non_terminating=False state_rewards = self.get_rewards(do_non_terminating=True) self.rewards_parents = torch.zeros_like(state_rewards) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 356b00c57..49d0b07ba 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -71,6 +71,7 @@ def corners(): def tetris_score(): return TetrisScore(device="cpu", float_precision=32, normalize=False) + @pytest.fixture() def tetris_score_norm(): return TetrisScore(device="cpu", float_precision=32, normalize=True) @@ -1330,13 +1331,16 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive( traj_indices_batch, tlong(traj_indices_consecutive, device=batch.device) ) + @pytest.mark.repeat(N_REPETITIONS) @pytest.mark.parametrize( "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], ) # @pytest.mark.skip(reason="skip while developping other tests") -def test__get_rewards__single_env_returns_expected_non_terminal(env, proxy, batch, request): +def test__get_rewards__single_env_returns_expected_non_terminal( + env, proxy, batch, request +): env = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) env = env.reset() @@ -1367,14 +1371,16 @@ def test__get_rewards__single_env_returns_expected_non_terminal(env, proxy, batc "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], ) -def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(env, proxy, batch, request): +def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal( + env, proxy, batch, request +): batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) env_ref = env_ref.reset() env_ref.proxy = proxy env_ref.setup_proxy() - env_ref.reward_func = 'boltzmann' + env_ref.reward_func = "boltzmann" batch.set_env(env_ref) @@ -1387,7 +1393,6 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(env, p rewards = [] proxy_values = [] - # Iterate until envs is empty while envs: actions_iter = [] @@ -1414,15 +1419,24 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal(env, p rewards_batch, tfloat(rewards, device=batch.device, float_type=batch.float), ), (rewards, rewards_batch) - assert ~torch.any(torch.isclose(rewards_batch, torch.zeros_like(rewards_batch))), rewards_batch + assert ~torch.any( + torch.isclose(rewards_batch, torch.zeros_like(rewards_batch)) + ), rewards_batch + @pytest.mark.repeat(N_REPETITIONS) # @pytest.mark.skip(reason="skip while developping other tests") @pytest.mark.parametrize( "env, proxy", - [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm"), ("ctorus2d5l", "corners")], + [ + ("grid2d", "corners"), + ("tetris6x4", "tetris_score_norm"), + ("ctorus2d5l", "corners"), + ], ) -def test__get_rewards_parents_multiple_env_returns_expected_non_terminal(env, proxy, batch, request): +def test__get_rewards_parents_multiple_env_returns_expected_non_terminal( + env, proxy, batch, request +): batch_size = BATCH_SIZE env_ref = request.getfixturevalue(env) proxy = request.getfixturevalue(proxy) @@ -1456,7 +1470,9 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminal(env, pr # Add to iter lists actions_iter.append(action) valids_iter.append(valid) - rewards_parents.append(env.reward(state=parent, done=done_parent, do_non_terminating=True)) + rewards_parents.append( + env.reward(state=parent, done=done_parent, do_non_terminating=True) + ) rewards.append(env.reward(do_non_terminating=True)) # Add all envs, actions and valids to batch batch.add_to_batch(envs, actions_iter, valids_iter) @@ -1477,4 +1493,4 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminal(env, pr assert torch.equal( rewards_batch, tfloat(rewards, device=batch.device, float_type=batch.float), - ), (rewards, rewards_batch) \ No newline at end of file + ), (rewards, rewards_batch) From 741f305288bb887bef32464d05eb869b6c660843 Mon Sep 17 00:00:00 2001 From: AlexandraVolokhova Date: Mon, 27 Nov 2023 18:24:35 -0500 Subject: [PATCH 21/34] move squeeze to state flow call --- gflownet/gflownet.py | 4 ++-- gflownet/policy/state_flow.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index e686a0b0d..7372a215b 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -725,11 +725,11 @@ def forwardlooking_loss(self, it, batch): policy_output_f, actions, masks_f, parents, is_backward=False ) - states_log_flflow = self.state_flow(states_policy).squeeze() + states_log_flflow = self.state_flow(states_policy) # forward-looking flow is 1 in the terminal states states_log_flflow[done.eq(1)] = 0.0 # Can be optimised by reusing states_log_flflow and batch.get_parent_indices - parents_log_flflow = self.state_flow(parents_policy).squeeze() + parents_log_flflow = self.state_flow(parents_policy) rewards_states = batch.get_rewards(do_non_terminating=True) rewards_parents = batch.get_rewards_parents() diff --git a/gflownet/policy/state_flow.py b/gflownet/policy/state_flow.py index 8542c8a15..35917a284 100644 --- a/gflownet/policy/state_flow.py +++ b/gflownet/policy/state_flow.py @@ -25,3 +25,9 @@ def instantiate(self): self.is_model = True else: raise "StateFlow model type not defined" + + def __call__(self, states): + """ + Returns a tensor of the state flows of the shape (batch_size, ) + """ + return super().__call__(states).squeeze() From 993b9de3a28497512ad255ba8a49f30cc1ff2d85 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 18:24:43 -0500 Subject: [PATCH 22/34] terminal -> terminating (for consistency) --- tests/gflownet/utils/test_batch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 49d0b07ba..2ccd85a58 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1338,7 +1338,7 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive( [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], ) # @pytest.mark.skip(reason="skip while developping other tests") -def test__get_rewards__single_env_returns_expected_non_terminal( +def test__get_rewards__single_env_returns_expected_non_terminating( env, proxy, batch, request ): env = request.getfixturevalue(env) @@ -1371,7 +1371,7 @@ def test__get_rewards__single_env_returns_expected_non_terminal( "env, proxy", [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], ) -def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal( +def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( env, proxy, batch, request ): batch_size = BATCH_SIZE @@ -1434,7 +1434,7 @@ def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminal( ("ctorus2d5l", "corners"), ], ) -def test__get_rewards_parents_multiple_env_returns_expected_non_terminal( +def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( env, proxy, batch, request ): batch_size = BATCH_SIZE From 06a6fb5d94ee67cbc771b687db61fd3b03872725 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 18:36:23 -0500 Subject: [PATCH 23/34] Extend docstring and wrap docstring lines --- gflownet/utils/batch.py | 48 ++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index f71dff505..ca84f9198 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -525,9 +525,16 @@ def get_parents( def get_parents_indices(self): """ - Returns indices of the parents of the states in the batch. - Each index corresponds to the position of the patent in the self.states tensor, if it is peresent there. - If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 + Returns the indices of the parents of the states in the batch. + + Each item idx in the returned list corresponds to the index in self.states that + contains the parent of self.states[idx], if it is peresent there. If a parent + is not present in self.states (because it is the source), the index is -1. + + Returns + ------- + self.parents_indices + The indices in self.states of the parents of self.states. """ if self.parents_available is False: self._compute_parents() @@ -535,17 +542,20 @@ def get_parents_indices(self): def _compute_parents(self): """ - Obtains the parent (single parent for each state) of all states in the batch and its index. + Obtains the parent (single parent for each state) of all states in the batch + and its index. + The parents are computed, obtaining all necessary components, if they are not readily available. Missing components and newly computed components are added - to the batch (self.component is set). The following variable is stored: + to the batch (self.component is set). The following variables are stored: - self.parents: the parent of each state in the batch. It will be the same type as self.states (list of lists or tensor) Length: n_states Shape: [n_states, state_dims] - - self.parents_indices: the position of each parent in self.states tensor. - If a parent is not present in self.states (i.e. it is source), the corresponding index is -1 + - self.parents_indices: the position of each parent in self.states tensor. If a + parent is not present in self.states (i.e. it is source), the corresponding + index is -1. self.parents_available is set to True. """ @@ -887,7 +897,12 @@ def _compute_rewards(self, do_non_terminating: Optional[bool] = False): def get_rewards_parents(self) -> TensorType["n_states"]: """ - Returns the rewards of all parents in the batch + Returns the rewards of all parents in the batch. + + Returns + ------- + self.rewards_parents + A tensor containing the rewards of the parents of self.states. """ if not self.rewards_parents_available: self._compute_rewards_parents() @@ -895,8 +910,10 @@ def get_rewards_parents(self) -> TensorType["n_states"]: def _compute_rewards_parents(self): """ - Computes rewards of the self.parents by reusing rewards of the states (i.e. self.rewards). - Stores the result in self.rewards_parents + Computes the rewards of self.parents by reusing the rewards of the states + (self.rewards). + + Stores the result in self.rewards_parents. """ # TODO: this may return zero rewards for all parents if before # rewards for states were computed with do_non_terminating=False @@ -914,6 +931,11 @@ def _compute_rewards_parents(self): def get_rewards_source(self) -> TensorType["n_states"]: """ Returns rewards of the corresponding source states for each state in the batch. + + Returns + ------- + self.rewards_source + A tensor containing the rewards the source states. """ if not self.rewards_source_available: self._compute_rewards_source() @@ -921,8 +943,10 @@ def get_rewards_source(self) -> TensorType["n_states"]: def _compute_rewards_source(self): """ - Computes a tensor of length len(self.states) with rewards of the corresponding source states. - Stores the result in self.rewards_source + Computes a tensor of length len(self.states) with the rewards of the + corresponding source states. + + Stores the result in self.rewards_source. """ # This will not work if source is randomised if not self.conditional: From 6141b57b1d61c26b2a3e341f8fcdedd7f61a49e3 Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 28 Nov 2023 01:05:15 +0100 Subject: [PATCH 24/34] Edit docstring Co-authored-by: Alexandra --- gflownet/utils/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index ca84f9198..2e39ff0bb 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -527,7 +527,7 @@ def get_parents_indices(self): """ Returns the indices of the parents of the states in the batch. - Each item idx in the returned list corresponds to the index in self.states that + Each i-th item in the returned list contains the index in self.states that contains the parent of self.states[idx], if it is peresent there. If a parent is not present in self.states (because it is the source), the index is -1. From 660d5eff9b13b5361293571195be30d3dbc38b1d Mon Sep 17 00:00:00 2001 From: Alex Date: Tue, 28 Nov 2023 01:05:23 +0100 Subject: [PATCH 25/34] Edit docstring Co-authored-by: Alexandra --- gflownet/utils/batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 2e39ff0bb..8f914af04 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -528,7 +528,7 @@ def get_parents_indices(self): Returns the indices of the parents of the states in the batch. Each i-th item in the returned list contains the index in self.states that - contains the parent of self.states[idx], if it is peresent there. If a parent + contains the parent of self.states[i], if it is present there. If a parent is not present in self.states (because it is the source), the index is -1. Returns From 293010f29fb754404587f0d8f60091bd2d0c43c0 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:08:14 -0500 Subject: [PATCH 26/34] Compute DB loss on terminating and intermediate states --- gflownet/gflownet.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index abbd82053..b8ed7d33e 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -736,10 +736,11 @@ def detailedbalance_loss(self, it, batch): logflow_parents = self.state_flow(parents_policy).squeeze() # Detailed balance loss - loss = ( - (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2).mean() - ) - return loss, loss, loss + loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2) + loss = loss_all.mean() + loss_terminating = loss_all[done].mean() + loss_intermediate = loss_all[~done].mean() + return loss, loss_terminating, loss_intermediate def forwardlooking_loss(self, it, batch): """ From 798b5eff3ab7eff23018d0941fb0951a35faf935 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:09:26 -0500 Subject: [PATCH 27/34] Remove squeeze in DB loss --- gflownet/gflownet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index b8ed7d33e..844b560ea 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -730,10 +730,10 @@ def detailedbalance_loss(self, it, batch): ) # Get logflows - logflow_states = self.state_flow(states_policy).squeeze() + logflow_states = self.state_flow(states_policy) logflow_states[done.eq(1)] = torch.log(rewards) # TODO: Optimise by reusing logflow_states and batch.get_parent_indices - logflow_parents = self.state_flow(parents_policy).squeeze() + logflow_parents = self.state_flow(parents_policy) # Detailed balance loss loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2) From 20285ffaf52d489311f36597335033d6ac4ad36e Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:33:11 -0500 Subject: [PATCH 28/34] Adapt names of variables for consistency --- gflownet/gflownet.py | 71 ++++++++++++++++++++++---------------------- 1 file changed, 35 insertions(+), 36 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 844b560ea..880189cdd 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -730,13 +730,13 @@ def detailedbalance_loss(self, it, batch): ) # Get logflows - logflow_states = self.state_flow(states_policy) - logflow_states[done.eq(1)] = torch.log(rewards) - # TODO: Optimise by reusing logflow_states and batch.get_parent_indices - logflow_parents = self.state_flow(parents_policy) + logflows_states = self.state_flow(states_policy) + logflows_states[done.eq(1)] = torch.log(rewards) + # TODO: Optimise by reusing logflows_states and batch.get_parent_indices + logflows_parents = self.state_flow(parents_policy) # Detailed balance loss - loss_all = (logflow_parents + logprobs_f - logflow_states - logprobs_b).pow(2) + loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2) loss = loss_all.mean() loss_terminating = loss_all[done].mean() loss_intermediate = loss_all[~done].mean() @@ -769,49 +769,48 @@ def forwardlooking_loss(self, it, batch): assert batch.is_valid() # Get necessary tensors from batch - states_policy = batch.get_states(policy=True) states = batch.get_states(policy=False) + states_policy = batch.get_states(policy=True) actions = batch.get_actions() - parents_policy = batch.get_parents(policy=True) parents = batch.get_parents(policy=False) - traj_indices = batch.get_trajectory_indices(consecutive=True) + parents_policy = batch.get_parents(policy=True) + rewards_states = batch.get_rewards(do_non_terminating=True) + rewards_parents = batch.get_rewards_parents() done = batch.get_done() - masks_b = batch.get_masks_backward() - policy_output_b = self.backward_policy(states_policy) - logprobs_bkw = self.env.get_logprobs( - policy_output_b, actions, masks_b, states, is_backward=True - ) + # Get logprobs masks_f = batch.get_masks_forward(of_parents=True) policy_output_f = self.forward_policy(parents_policy) - logprobs_fwd = self.env.get_logprobs( + logprobs_f = self.env.get_logprobs( policy_output_f, actions, masks_f, parents, is_backward=False ) + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_b = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) - states_log_flflow = self.state_flow(states_policy) - # forward-looking flow is 1 in the terminal states - states_log_flflow[done.eq(1)] = 0.0 - # Can be optimised by reusing states_log_flflow and batch.get_parent_indices - parents_log_flflow = self.state_flow(parents_policy) - - rewards_states = batch.get_rewards(do_non_terminating=True) - rewards_parents = batch.get_rewards_parents() - energies_states = -torch.log(rewards_states) - energies_parents = -torch.log(rewards_parents) - - per_node_loss = ( - parents_log_flflow - - states_log_flflow - + logprobs_fwd - - logprobs_bkw - + energies_states - - energies_parents + # Get FL logflows + logflflows_states = self.state_flow(states_policy) + # Log FL flow of terminal states is 0 (eq. 9 of paper) + logflflows_states[done.eq(1)] = 0.0 + # TODO: Optimise by reusing logflows_states and batch.get_parent_indices + logflflows_parents = self.state_flow(parents_policy) + + # Get energies transitions + energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states) + + # Forward-looking loss + loss_all = ( + logflflows_parents + - logflflows_states + + logprobs_f + - logprobs_b + + energies_transitions ).pow(2) - - term_loss = per_node_loss[done].mean() - nonterm_loss = per_node_loss[~done].mean() loss = per_node_loss.mean() - + loss_terminating = per_node_loss[done].mean() + loss_intermediate = per_node_loss[~done].mean() return loss, term_loss, nonterm_loss @torch.no_grad() From a777e9994f397a2b6aea4854fbcb5aef031a4905 Mon Sep 17 00:00:00 2001 From: alexhernandezgarcia Date: Mon, 27 Nov 2023 19:33:11 -0500 Subject: [PATCH 29/34] Fixes (cherry-pick) --- gflownet/gflownet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index 880189cdd..6acc0ddfe 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -808,10 +808,10 @@ def forwardlooking_loss(self, it, batch): - logprobs_b + energies_transitions ).pow(2) - loss = per_node_loss.mean() - loss_terminating = per_node_loss[done].mean() - loss_intermediate = per_node_loss[~done].mean() - return loss, term_loss, nonterm_loss + loss = loss_all.mean() + loss_terminating = loss_all[done].mean() + loss_intermediate = loss_all[~done].mean() + return loss, loss_terminating, loss_intermediate @torch.no_grad() def estimate_logprobs_data( From 08ee25b54aaa12215f155d1824ea39bd93ab1d26 Mon Sep 17 00:00:00 2001 From: Alexandra Date: Wed, 29 Nov 2023 11:57:48 -0500 Subject: [PATCH 30/34] if smth is False -> if not smth Co-authored-by: Alex --- gflownet/envs/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index e1cae5b86..118fe728f 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -786,7 +786,7 @@ def reward(self, state=None, done=None, do_non_terminating=False): """ state = self._get_state(state) done = self._get_done(done) - if done is False and do_non_terminating is False: + if not done and not do_non_terminating: return tfloat(0.0, float_type=self.float, device=self.device) return self.proxy2reward(self.proxy(self.state2proxy(state))[0]) From 0007bcc4314a3c6980acbb5cacc2af28be753142 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 7 Dec 2023 17:48:50 -0500 Subject: [PATCH 31/34] Assert that env is not done for parents. --- tests/gflownet/utils/test_batch.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 2ccd85a58..ca64ac85c 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -1462,7 +1462,7 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( # Make step env by env (different to GFN Agent) to have full control for env in envs: parent = copy(env.state) - done_parent = env.done + assert env.done is False # Sample random action state, action, valid = env.step_random() @@ -1471,7 +1471,7 @@ def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( actions_iter.append(action) valids_iter.append(valid) rewards_parents.append( - env.reward(state=parent, done=done_parent, do_non_terminating=True) + env.reward(state=parent, done=False, do_non_terminating=True) ) rewards.append(env.reward(do_non_terminating=True)) # Add all envs, actions and valids to batch From 4390b1781682b399851ea936927f9f3a93ba70c4 Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 7 Dec 2023 18:09:51 -0500 Subject: [PATCH 32/34] Add section about losses in README --- README.md | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/README.md b/README.md index 796f8dc3b..a6e8ce208 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,15 @@ The above command will overwrite the `env` and `proxy` default configuration wit Hydra configuration is hierarchical. For instance, a handy variable to change while debugging our code is to avoid logging to wandb. You can do this by setting `logger.do.online=False`. +## GFlowNet loss functions + +Currently, the implementation includes the following GFlowNet losses: + +- [Flow-matching (FM)](https://arxiv.org/abs/2106.04399): `gflownet=flowmatch` +- [Trajectory balance (TB)](https://arxiv.org/abs/2201.13259): `gflownet=trajectorybalance` +- [Detailed balance (DB)](https://arxiv.org/abs/2201.13259): `gflownet=detailedbalance` +- [Forward-looking (FL)](https://arxiv.org/abs/2302.01687): `gflownet=forwardlooking` + ## Logging to wandb The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`. From d3b0126cceeebfac9cf493d10cadaa85a34d9e2e Mon Sep 17 00:00:00 2001 From: Alex Date: Thu, 7 Dec 2023 23:14:26 -0500 Subject: [PATCH 33/34] Add DB and FL runs to sanity checks. --- mila/dev/sanity_check_runs.yaml | 48 +++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 834d329ef..89427c2b8 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -24,6 +24,20 @@ jobs: __value__: grid length: 10 gflownet: trajectorybalance + - slurm: + job_name: sanity-grid-db + script: + env: + __value__: grid + length: 10 + gflownet: detailedbalance + - slurm: + job_name: sanity-grid-fl + script: + env: + __value__: grid + length: 10 + gflownet: forwardlooking # Tetris - slurm: job_name: sanity-tetris-fm @@ -43,6 +57,40 @@ jobs: height: 10 gflownet: trajectorybalance proxy: tetris + # Mini-Tetris + - slurm: + job_name: sanity-mintetris-fm + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + gflownet: flowmatch + proxy: tetris + - slurm: + job_name: sanity-mintetris-tb + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + gflownet: trajectorybalance + proxy: tetris + - slurm: + job_name: sanity-mintetris-fl + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + gflownet: forwardlooking + proxy: tetris # Ctorus - slurm: job_name: sanity-ctorus From fe37c34ebb89f4dec76f10d0405d69306f5a1f2d Mon Sep 17 00:00:00 2001 From: Alex Date: Sat, 9 Dec 2023 13:02:28 -0500 Subject: [PATCH 34/34] Make reward function boltzmann in mini-tetris sanity check experiments --- config/user/alex.yaml | 2 +- mila/dev/sanity_check_runs.yaml | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/config/user/alex.yaml b/config/user/alex.yaml index 1d308a263..59a7db7ab 100644 --- a/config/user/alex.yaml +++ b/config/user/alex.yaml @@ -1,5 +1,5 @@ logdir: - root: /network/scratch/h/hernanga/logs/gflownet + root: /home/alex/logs/gflownet data: root: /home/mila/h/hernanga/gflownet/data alanine_dipeptide: /home/mila/h/hernanga/gflownet/data/alanine_dipeptide_conformers_1.npy diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 89427c2b8..443c235ed 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -67,6 +67,7 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True + reward_func: boltzmann gflownet: flowmatch proxy: tetris - slurm: @@ -78,6 +79,7 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True + reward_func: boltzmann gflownet: trajectorybalance proxy: tetris - slurm: @@ -89,6 +91,7 @@ jobs: height: 10 pieces: ["J", "L", "S", "Z"] allow_eos_before_full: True + reward_func: boltzmann gflownet: forwardlooking proxy: tetris # Ctorus