diff --git a/README.md b/README.md index ba2171ca4..24a6b7944 100644 --- a/README.md +++ b/README.md @@ -47,6 +47,15 @@ The above command will overwrite the `env` and `proxy` default configuration wit Hydra configuration is hierarchical. For instance, a handy variable to change while debugging our code is to avoid logging to wandb. You can do this by setting `logger.do.online=False`. +## GFlowNet loss functions + +Currently, the implementation includes the following GFlowNet losses: + +- [Flow-matching (FM)](https://arxiv.org/abs/2106.04399): `gflownet=flowmatch` +- [Trajectory balance (TB)](https://arxiv.org/abs/2201.13259): `gflownet=trajectorybalance` +- [Detailed balance (DB)](https://arxiv.org/abs/2201.13259): `gflownet=detailedbalance` +- [Forward-looking (FL)](https://arxiv.org/abs/2302.01687): `gflownet=forwardlooking` + ## Logging to wandb The repository supports logging of train and evaluation metrics to [wandb.ai](https://wandb.ai), but it is disabled by default. In order to enable it, set the configuration variable `logger.do.online` to `True`. diff --git a/config/experiments/simple_tetris.yaml b/config/experiments/simple_tetris.yaml new file mode 100644 index 000000000..716169dae --- /dev/null +++ b/config/experiments/simple_tetris.yaml @@ -0,0 +1,47 @@ +# @package _global_ + +defaults: + - override /env: tetris + - override /gflownet: trajectorybalance + - override /policy: mlp + - override /proxy: tetris + - override /logger: wandb + +env: + reward_func: boltzmann + reward_beta: 10.0 + width: 4 + height: 4 + pieces: ["I", "O", "J", "L", "T"] + rotations: [0, 90, 180, 270] + buffer: + # replay_capacity: 0 + test: + type: random + output_csv: simple_tetris_val.csv + output_pkl: simple_tetris_val.pkl + n: 100 + +gflownet: + random_action_prob: 0.3 + optimizer: + n_train_steps: 10000 + lr_z_mult: 100 + lr: 0.0001 + +policy: + forward: + type: mlp + n_hid: 128 + n_layers: 5 + + backward: + shared_weights: True + checkpoint: null + reload_ckpt: False + +device: cpu +logger: + do: + online: True + project_name: simple_tetris \ No newline at end of file diff --git a/config/gflownet/detailedbalance.yaml b/config/gflownet/detailedbalance.yaml new file mode 100644 index 000000000..073ae99f4 --- /dev/null +++ b/config/gflownet/detailedbalance.yaml @@ -0,0 +1,9 @@ +defaults: + - gflownet + - state_flow: mlp + +optimizer: + loss: detailedbalance + lr: 0.0001 + lr_decay_period: 1000000 + lr_decay_gamma: 0.5 diff --git a/config/gflownet/forwardlooking.yaml b/config/gflownet/forwardlooking.yaml new file mode 100644 index 000000000..c2c641719 --- /dev/null +++ b/config/gflownet/forwardlooking.yaml @@ -0,0 +1,9 @@ +defaults: + - gflownet + - state_flow: mlp + +optimizer: + loss: forwardlooking + lr: 0.0001 + lr_decay_period: 1000000 + lr_decay_gamma: 0.5 diff --git a/config/gflownet/gflownet.yaml b/config/gflownet/gflownet.yaml index 22dd6dd10..c33e52eaf 100644 --- a/config/gflownet/gflownet.yaml +++ b/config/gflownet/gflownet.yaml @@ -34,6 +34,8 @@ optimizer: # From original implementation bootstrap_tau: 0.0 clip_grad_norm: 0.0 +# State flow modelling +state_flow: null # If True, compute rewards in batches batch_reward: True # Force zero probability of sampling invalid actions diff --git a/config/gflownet/state_flow/mlp.yaml b/config/gflownet/state_flow/mlp.yaml new file mode 100644 index 000000000..6ccf772ee --- /dev/null +++ b/config/gflownet/state_flow/mlp.yaml @@ -0,0 +1,9 @@ +_target_: gflownet.policy.state_flow.StateFlow + +config: + type: mlp + n_hid: 128 + n_layers: 2 + checkpoint: 
null + reload_ckpt: False + shared_weights: False \ No newline at end of file diff --git a/config/policy/mlp_detailedbalance.yaml b/config/policy/mlp_detailedbalance.yaml new file mode 100644 index 000000000..41f43231e --- /dev/null +++ b/config/policy/mlp_detailedbalance.yaml @@ -0,0 +1,7 @@ +defaults: + - mlp + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/config/policy/mlp_forwardlooking.yaml b/config/policy/mlp_forwardlooking.yaml new file mode 100644 index 000000000..41f43231e --- /dev/null +++ b/config/policy/mlp_forwardlooking.yaml @@ -0,0 +1,7 @@ +defaults: + - mlp + +backward: + shared_weights: True + checkpoint: null + reload_ckpt: False diff --git a/config/user/alex.yaml b/config/user/alex.yaml index 1d308a263..59a7db7ab 100644 --- a/config/user/alex.yaml +++ b/config/user/alex.yaml @@ -1,5 +1,5 @@ logdir: - root: /network/scratch/h/hernanga/logs/gflownet + root: /home/alex/logs/gflownet data: root: /home/mila/h/hernanga/gflownet/data alanine_dipeptide: /home/mila/h/hernanga/gflownet/data/alanine_dipeptide_conformers_1.npy diff --git a/gflownet/envs/base.py b/gflownet/envs/base.py index ab68825a3..4cd092cad 100644 --- a/gflownet/envs/base.py +++ b/gflownet/envs/base.py @@ -762,13 +762,13 @@ def traj2readable(self, traj=None): """ return str(traj).replace("(", "[").replace(")", "]").replace(",", "") - def reward(self, state=None, done=None): + def reward(self, state=None, done=None, do_non_terminating=False): """ Computes the reward of a state """ state = self._get_state(state) done = self._get_done(done) - if done is False: + if not done and not do_non_terminating: return tfloat(0.0, float_type=self.float, device=self.device) return self.proxy2reward( self.proxy(torch.unsqueeze(self.state2proxy(state), dim=0))[0] diff --git a/gflownet/gflownet.py b/gflownet/gflownet.py index fdbea6e76..bd94681b4 100644 --- a/gflownet/gflownet.py +++ b/gflownet/gflownet.py @@ -50,6 +50,7 @@ def __init__( logger, num_empirical_loss, oracle, + state_flow=None, active_learning=False, sample_only=False, replay_sampling="permutation", @@ -79,6 +80,12 @@ def __init__( elif optimizer.loss in ["trajectorybalance", "tb"]: self.loss = "trajectorybalance" self.logZ = nn.Parameter(torch.ones(optimizer.z_dim) * 150.0 / 64) + elif optimizer.loss in ["detailedbalance", "db"]: + self.loss = "detailedbalance" + self.logZ = None + elif optimizer.loss in ["forwardlooking", "fl"]: + self.loss = "forwardlooking" + self.logZ = None else: print("Unkown loss. 
Using flowmatch as default") self.loss = "flowmatch" @@ -121,7 +128,8 @@ def __init__( print(f"\tStd score: {self.buffer.test['energies'].std()}") print(f"\tMin score: {self.buffer.test['energies'].min()}") print(f"\tMax score: {self.buffer.test['energies'].max()}") - # Policy models + + # Models self.forward_policy = forward_policy if self.forward_policy.checkpoint is not None: self.logger.set_forward_policy_ckpt_path(self.forward_policy.checkpoint) @@ -133,6 +141,7 @@ def __init__( print("Reloaded GFN forward policy model Checkpoint") else: self.logger.set_forward_policy_ckpt_path(None) + self.backward_policy = backward_policy self.logger.set_backward_policy_ckpt_path(None) if self.backward_policy.checkpoint is not None: @@ -145,6 +154,14 @@ def __init__( print("Reloaded GFN backward policy model Checkpoint") else: self.logger.set_backward_policy_ckpt_path(None) + + self.state_flow = state_flow + if self.state_flow is not None and self.state_flow.checkpoint is not None: + self.logger.set_state_flow_ckpt_path(self.state_flow.checkpoint) + # TODO: add the logic and conditions to reload a model + else: + self.logger.set_state_flow_ckpt_path(None) + # Optimizer if self.forward_policy.is_model: self.target = copy.deepcopy(self.forward_policy.model) @@ -178,14 +195,16 @@ def __init__( self.nll_tt = 0.0 def parameters(self): - if self.backward_policy.is_model is False: - return list(self.forward_policy.model.parameters()) - elif self.loss == "trajectorybalance": - return list(self.forward_policy.model.parameters()) + list( - self.backward_policy.model.parameters() - ) - else: - raise ValueError("Backward Policy cannot be a nn in flowmatch.") + parameters = list(self.forward_policy.model.parameters()) + if self.backward_policy.is_model: + if self.loss == "flowmatch": + raise ValueError("Backward Policy cannot be a model in flowmatch.") + parameters += list(self.backward_policy.model.parameters()) + if self.state_flow is not None: + if self.loss not in ["detailedbalance", "forwardlooking"]: + raise ValueError(f"State flow cannot be trained with {self.loss} loss.") + parameters += list(self.state_flow.model.parameters()) + return parameters def sample_actions( self, @@ -662,6 +681,137 @@ def trajectorybalance_loss(self, it, batch): ) return loss, loss, loss + def detailedbalance_loss(self, it, batch): + """ + Computes the Detailed Balance GFlowNet loss of a batch + Reference : https://arxiv.org/pdf/2201.13259.pdf (eq 11) + + Args + ---- + it : int + Iteration + + batch : Batch + A batch of data, containing all the states in the trajectories. 
+ + + Returns + ------- + loss : float + + term_loss : float + Loss of the terminal nodes only + + nonterm_loss : float + Loss of the intermediate nodes only + """ + + assert batch.is_valid() + # Get necessary tensors from batch + states = batch.get_states(policy=False) + states_policy = batch.get_states(policy=True) + actions = batch.get_actions() + parents = batch.get_parents(policy=False) + parents_policy = batch.get_parents(policy=True) + done = batch.get_done() + rewards = batch.get_terminating_rewards(sort_by="insertion") + + # Get logprobs + masks_f = batch.get_masks_forward(of_parents=True) + policy_output_f = self.forward_policy(parents_policy) + logprobs_f = self.env.get_logprobs( + policy_output_f, actions, masks_f, parents, is_backward=False + ) + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_b = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) + + # Get logflows + logflows_states = self.state_flow(states_policy) + logflows_states[done.eq(1)] = torch.log(rewards) + # TODO: Optimise by reusing logflows_states and batch.get_parent_indices + logflows_parents = self.state_flow(parents_policy) + + # Detailed balance loss + loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2) + loss = loss_all.mean() + loss_terminating = loss_all[done].mean() + loss_intermediate = loss_all[~done].mean() + return loss, loss_terminating, loss_intermediate + + def forwardlooking_loss(self, it, batch): + """ + Computes the Forward-Looking GFlowNet loss of a batch + Reference : https://arxiv.org/pdf/2302.01687.pdf + + Args + ---- + it : int + Iteration + + batch : Batch + A batch of data, containing all the states in the trajectories. + + + Returns + ------- + loss : float + + term_loss : float + Loss of the terminal nodes only + + nonterm_loss : float + Loss of the intermediate nodes only + """ + + assert batch.is_valid() + # Get necessary tensors from batch + states = batch.get_states(policy=False) + states_policy = batch.get_states(policy=True) + actions = batch.get_actions() + parents = batch.get_parents(policy=False) + parents_policy = batch.get_parents(policy=True) + rewards_states = batch.get_rewards(do_non_terminating=True) + rewards_parents = batch.get_rewards_parents() + done = batch.get_done() + + # Get logprobs + masks_f = batch.get_masks_forward(of_parents=True) + policy_output_f = self.forward_policy(parents_policy) + logprobs_f = self.env.get_logprobs( + policy_output_f, actions, masks_f, parents, is_backward=False + ) + masks_b = batch.get_masks_backward() + policy_output_b = self.backward_policy(states_policy) + logprobs_b = self.env.get_logprobs( + policy_output_b, actions, masks_b, states, is_backward=True + ) + + # Get FL logflows + logflflows_states = self.state_flow(states_policy) + # Log FL flow of terminal states is 0 (eq. 
9 of paper) + logflflows_states[done.eq(1)] = 0.0 + # TODO: Optimise by reusing logflows_states and batch.get_parent_indices + logflflows_parents = self.state_flow(parents_policy) + + # Get energies transitions + energies_transitions = torch.log(rewards_parents) - torch.log(rewards_states) + + # Forward-looking loss + loss_all = ( + logflflows_parents + - logflflows_states + + logprobs_f + - logprobs_b + + energies_transitions + ).pow(2) + loss = loss_all.mean() + loss_terminating = loss_all[done].mean() + loss_intermediate = loss_all[~done].mean() + return loss, loss_terminating, loss_intermediate + @torch.no_grad() def estimate_logprobs_data( self, @@ -874,6 +1024,10 @@ def train(self): losses = self.trajectorybalance_loss( it * self.ttsr + j, batch ) # returns (opt loss, *metrics) + elif self.loss == "detailedbalance": + losses = self.detailedbalance_loss(it * self.ttsr + j, batch) + elif self.loss == "forwardlooking": + losses = self.forwardlooking_loss(it * self.ttsr + j, batch) else: print("Unknown loss!") # TODO: deal with this in a better way @@ -937,7 +1091,9 @@ def train(self): times.update({"log": t1_log - t0_log}) # Save intermediate models t0_model = time.time() - self.logger.save_models(self.forward_policy, self.backward_policy, step=it) + self.logger.save_models( + self.forward_policy, self.backward_policy, self.state_flow, step=it + ) t1_model = time.time() times.update({"save_interim_model": t1_model - t0_model}) @@ -966,7 +1122,9 @@ def train(self): self.logger.log_time(times, use_context=self.use_context) # Save final model - self.logger.save_models(self.forward_policy, self.backward_policy, final=True) + self.logger.save_models( + self.forward_policy, self.backward_policy, self.state_flow, final=True + ) # Close logger if self.use_context is False: self.logger.end() diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index 50b625e09..34968a6c1 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,3 +1,5 @@ +from abc import ABC, abstractmethod + import torch from omegaconf import OmegaConf from torch import nn @@ -5,7 +7,7 @@ from gflownet.utils.common import set_device, set_float_precision -class Policy: +class ModelBase(ABC): def __init__(self, config, env, device, float_precision, base=None): # Device and float precision self.device = set_device(device) @@ -19,33 +21,17 @@ def __init__(self, config, env, device, float_precision, base=None): self.base = base self.parse_config(config) - self.instantiate() def parse_config(self, config): # If config is null, default to uniform if config is None: config = OmegaConf.create() config.type = "uniform" - if "checkpoint" in config: - self.checkpoint = config.checkpoint - else: - self.checkpoint = None - if "shared_weights" in config: - self.shared_weights = config.shared_weights - else: - self.shared_weights = False - if "n_hid" in config: - self.n_hid = config.n_hid - else: - self.n_hid = None - if "n_layers" in config: - self.n_layers = config.n_layers - else: - self.n_layers = None - if "tail" in config: - self.tail = config.tail - else: - self.tail = [] + self.checkpoint = config.get("checkpoint", None) + self.shared_weights = config.get("shared_weights", False) + self.n_hid = config.get("n_hid", None) + self.n_layers = config.get("n_layers", None) + self.tail = config.get("tail", []) if "type" in config: self.type = config.type elif self.shared_weights: @@ -53,18 +39,9 @@ def parse_config(self, config): else: raise "Policy type must be defined if shared_weights is False" + @abstractmethod def 
instantiate(self): - if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False - elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False - elif self.type == "mlp": - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True - else: - raise "Policy model type not defined" + pass def __call__(self, states): return self.model(states) @@ -114,6 +91,26 @@ def make_mlp(self, activation): "Base Model must be provided when shared_weights is set to True" ) + +class Policy(ModelBase): + def __init__(self, config, env, device, float_precision, base=None): + super().__init__(config, env, device, float_precision, base) + + self.instantiate() + + def instantiate(self): + if self.type == "fixed": + self.model = self.fixed_distribution + self.is_model = False + elif self.type == "uniform": + self.model = self.uniform_distribution + self.is_model = False + elif self.type == "mlp": + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True + else: + raise "Policy model type not defined" + def fixed_distribution(self, states): """ Returns the fixed distribution specified by the environment. diff --git a/gflownet/policy/state_flow.py b/gflownet/policy/state_flow.py new file mode 100644 index 000000000..51f6378b3 --- /dev/null +++ b/gflownet/policy/state_flow.py @@ -0,0 +1,32 @@ +import torch +from torch import nn + +from gflownet.policy.base import ModelBase +from gflownet.utils.common import set_device, set_float_precision + + +class StateFlow(ModelBase): + """ + Takes state in the policy format and predicts its flow (a scalar) + """ + + def __init__(self, config, env, device, float_precision, base=None): + super().__init__(config, env, device, float_precision, base) + # Output dimension + self.output_dim = 1 + + # Instantiate neural network + self.instantiate() + + def instantiate(self): + if self.type == "mlp": + self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) + self.is_model = True + else: + raise "StateFlow model type not defined" + + def __call__(self, states): + """ + Returns a tensor of the state flows of the shape (batch_size, ) + """ + return super().__call__(states).squeeze() diff --git a/gflownet/utils/batch.py b/gflownet/utils/batch.py index 9ab3dc24d..a58c1a7e7 100644 --- a/gflownet/utils/batch.py +++ b/gflownet/utils/batch.py @@ -88,12 +88,14 @@ def __init__( self.states_policy = None self.parents_policy = None # Flags for available items - self.parents_available = True + self.parents_available = False self.parents_policy_available = False self.parents_all_available = False self.masks_forward_available = False self.masks_backward_available = False self.rewards_available = False + self.rewards_parents_available = False + self.rewards_source_available = False def __len__(self): return self.size @@ -516,33 +518,63 @@ def get_parents( else: return self.parents + def get_parents_indices(self): + """ + Returns the indices of the parents of the states in the batch. + + Each i-th item in the returned list contains the index in self.states that + contains the parent of self.states[i], if it is present there. If a parent + is not present in self.states (because it is the source), the index is -1. + + Returns + ------- + self.parents_indices + The indices in self.states of the parents of self.states. 
+ """ + if self.parents_available is False: + self._compute_parents() + return self.parents_indices + def _compute_parents(self): """ - Obtains the parent (single parent for each state) of all states in the batch. + Obtains the parent (single parent for each state) of all states in the batch + and its index. + The parents are computed, obtaining all necessary components, if they are not readily available. Missing components and newly computed components are added - to the batch (self.component is set). The following variable is stored: + to the batch (self.component is set). The following variables are stored: - self.parents: the parent of each state in the batch. It will be the same type as self.states (list of lists or tensor) Length: n_states Shape: [n_states, state_dims] + - self.parents_indices: the position of each parent in self.states tensor. If a + parent is not present in self.states (i.e. it is source), the corresponding + index is -1. self.parents_available is set to True. """ self.parents = [] + self.parents_indices = [] indices = [] # Iterate over the trajectories to obtain the parents from the states for traj_idx, batch_indices in self.trajectories.items(): # parent is source self.parents.append(self.envs[traj_idx].source) + # there's no source state in the batch + self.parents_indices.append(-1) # parent is not source # TODO: check if tensor and sort without iter self.parents.extend([self.states[idx] for idx in batch_indices[:-1]]) + self.parents_indices.extend(batch_indices[:-1]) indices.extend(batch_indices) # Sort parents list in the same order as states # TODO: check if tensor and sort without iter self.parents = [self.parents[indices.index(idx)] for idx in range(len(self))] + self.parents_indices = tlong( + [self.parents_indices[indices.index(idx)] for idx in range(len(self))], + device=self.device, + ) self.parents_available = True # TODO: consider converting directly from self.parents @@ -808,7 +840,9 @@ def _compute_masks_backward(self): self.masks_backward_available = True def get_rewards( - self, force_recompute: Optional[bool] = False + self, + force_recompute: Optional[bool] = False, + do_non_terminating: Optional[bool] = False, ) -> TensorType["n_states"]: """ Returns the rewards of all states in the batch (including not done). @@ -817,29 +851,100 @@ def get_rewards( ---- force_recompute : bool If True, the rewards are recomputed even if they are available. + + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. """ if self.rewards_available is False or force_recompute is True: - self._compute_rewards() + self._compute_rewards(do_non_terminating) return self.rewards - def _compute_rewards(self): + def _compute_rewards(self, do_non_terminating: Optional[bool] = False): """ Computes rewards for all self.states by first converting the states into proxy - format. + format. The result is stored in self.rewards as a torch.tensor + + Args + ---- + do_non_terminating : bool + If True, compute the rewards of the non-terminating states instead of + assigning reward 0. 
+ """ + + if do_non_terminating: + self.rewards = self.env.proxy2reward(self.env.proxy(self.states2proxy())) + else: + self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) + done = self.get_done() + if len(done) > 0: + states_proxy_done = self.get_terminating_states(proxy=True) + self.rewards[done] = self.env.proxy2reward( + self.env.proxy(states_proxy_done) + ) + self.rewards_available = True + + def get_rewards_parents(self) -> TensorType["n_states"]: + """ + Returns the rewards of all parents in the batch. Returns ------- - rewards: torch.tensor - Tensor of rewards. - """ - states_proxy_done = self.get_terminating_states(proxy=True) - self.rewards = torch.zeros(len(self), dtype=self.float, device=self.device) - done = self.get_done() - if len(done) > 0: - self.rewards[done] = self.env.proxy2reward( - self.env.proxy(states_proxy_done) - ) - self.rewards_available = True + self.rewards_parents + A tensor containing the rewards of the parents of self.states. + """ + if not self.rewards_parents_available: + self._compute_rewards_parents() + return self.rewards_parents + + def _compute_rewards_parents(self): + """ + Computes the rewards of self.parents by reusing the rewards of the states + (self.rewards). + + Stores the result in self.rewards_parents. + """ + # TODO: this may return zero rewards for all parents if before + # rewards for states were computed with do_non_terminating=False + state_rewards = self.get_rewards(do_non_terminating=True) + self.rewards_parents = torch.zeros_like(state_rewards) + parent_indices = self.get_parents_indices() + parent_is_source = parent_indices == -1 + self.rewards_parents[~parent_is_source] = self.rewards[ + parent_indices[~parent_is_source] + ] + rewards_source = self.get_rewards_source() + self.rewards_parents[parent_is_source] = rewards_source[parent_is_source] + self.rewards_parents_available = True + + def get_rewards_source(self) -> TensorType["n_states"]: + """ + Returns rewards of the corresponding source states for each state in the batch. + + Returns + ------- + self.rewards_source + A tensor containing the rewards the source states. + """ + if not self.rewards_source_available: + self._compute_rewards_source() + return self.rewards_source + + def _compute_rewards_source(self): + """ + Computes a tensor of length len(self.states) with the rewards of the + corresponding source states. + + Stores the result in self.rewards_source. 
+ """ + # This will not work if source is randomised + if not self.conditional: + source_proxy = self.env.state2proxy(self.env.source) + reward_source = self.env.proxy2reward(self.env.proxy(source_proxy)) + self.rewards_source = reward_source.expand(len(self)) + else: + raise NotImplementedError + self.rewards_source_available = True def get_terminating_states( self, diff --git a/gflownet/utils/logger.py b/gflownet/utils/logger.py index 50356b377..2e796410c 100644 --- a/gflownet/utils/logger.py +++ b/gflownet/utils/logger.py @@ -145,6 +145,12 @@ def set_backward_policy_ckpt_path(self, ckpt_id: str = None): else: self.pb_ckpt_path = self.ckpts_dir / f"{ckpt_id}_" + def set_state_flow_ckpt_path(self, ckpt_id: str = None): + if ckpt_id is None: + self.sf_ckpt_path = None + else: + self.sf_ckpt_path = self.ckpts_dir / f"{ckpt_id}_" + def progressbar_update( self, pbar, losses, rewards, jsd, step, use_context=True, n_mean=100 ): @@ -357,7 +363,7 @@ def log_test_metrics( ) def save_models( - self, forward_policy, backward_policy, step: int = 1e9, final=False + self, forward_policy, backward_policy, state_flow, step: int = 1e9, final=False ): if self.do_checkpoints(step) or final: if final: @@ -377,6 +383,11 @@ def save_models( path = self.pb_ckpt_path.parent / stem torch.save(backward_policy.model.state_dict(), path) + if state_flow is not None and self.sf_ckpt_path is not None: + stem = self.sf_ckpt_path.stem + self.context + ckpt_id + ".ckpt" + path = self.sf_ckpt_path.parent / stem + torch.save(state_flow.model.state_dict(), path) + def log_time(self, times: dict, use_context: bool): if self.do.times: times = {"time_{}".format(k): v for k, v in times.items()} diff --git a/main.py b/main.py index 62a5d064a..25171afc6 100644 --- a/main.py +++ b/main.py @@ -60,7 +60,18 @@ def main(config): float_precision=config.float_precision, base=forward_policy, ) - + # State flow + if config.gflownet.state_flow is not None: + state_flow = hydra.utils.instantiate( + config.gflownet.state_flow, + env=env, + device=config.device, + float_precision=config.float_precision, + base=forward_policy, + ) + else: + state_flow = None + # GFlowNet Agent gflownet = hydra.utils.instantiate( config.gflownet, device=config.device, @@ -68,9 +79,12 @@ def main(config): env=env, forward_policy=forward_policy, backward_policy=backward_policy, + state_flow=state_flow, buffer=config.env.buffer, logger=logger, ) + + # Train GFlowNet gflownet.train() # Sample from trained GFlowNet diff --git a/mila/dev/sanity_check_runs.yaml b/mila/dev/sanity_check_runs.yaml index 834d329ef..443c235ed 100644 --- a/mila/dev/sanity_check_runs.yaml +++ b/mila/dev/sanity_check_runs.yaml @@ -24,6 +24,20 @@ jobs: __value__: grid length: 10 gflownet: trajectorybalance + - slurm: + job_name: sanity-grid-db + script: + env: + __value__: grid + length: 10 + gflownet: detailedbalance + - slurm: + job_name: sanity-grid-fl + script: + env: + __value__: grid + length: 10 + gflownet: forwardlooking # Tetris - slurm: job_name: sanity-tetris-fm @@ -43,6 +57,43 @@ jobs: height: 10 gflownet: trajectorybalance proxy: tetris + # Mini-Tetris + - slurm: + job_name: sanity-mintetris-fm + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + reward_func: boltzmann + gflownet: flowmatch + proxy: tetris + - slurm: + job_name: sanity-mintetris-tb + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + reward_func: boltzmann + gflownet: 
trajectorybalance + proxy: tetris + - slurm: + job_name: sanity-mintetris-fl + script: + env: + __value__: tetris + width: 3 + height: 10 + pieces: ["J", "L", "S", "Z"] + allow_eos_before_full: True + reward_func: boltzmann + gflownet: forwardlooking + proxy: tetris # Ctorus - slurm: job_name: sanity-ctorus diff --git a/scripts/dav_mp20_stats.py b/scripts/dav_mp20_stats.py index 2b3e7ee5d..3df1c78c9 100644 --- a/scripts/dav_mp20_stats.py +++ b/scripts/dav_mp20_stats.py @@ -20,6 +20,7 @@ from collections import Counter from external.repos.ActiveLearningMaterials.dave.utils.loaders import make_loaders + from gflownet.proxy.crystals.dave import DAVE from gflownet.utils.common import load_gflow_net_from_run_path, resolve_path diff --git a/tests/gflownet/utils/test_batch.py b/tests/gflownet/utils/test_batch.py index 776bba324..2b3956038 100644 --- a/tests/gflownet/utils/test_batch.py +++ b/tests/gflownet/utils/test_batch.py @@ -72,6 +72,11 @@ def tetris_score(): return TetrisScore(device="cpu", float_precision=32, normalize=False) +@pytest.fixture() +def tetris_score_norm(): + return TetrisScore(device="cpu", float_precision=32, normalize=True) + + # @pytest.mark.skip(reason="skip while developping other tests") def test__len__returnszero_at_init(batch): assert len(batch) == 0 @@ -1192,3 +1197,167 @@ def test__make_indices_consecutive__multiplied_indices_become_consecutive( assert torch.equal( traj_indices_batch, tlong(traj_indices_consecutive, device=batch.device) ) + + +@pytest.mark.repeat(N_REPETITIONS) +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score"), ("ctorus2d5l", "corners")], +) +# @pytest.mark.skip(reason="skip while developping other tests") +def test__get_rewards__single_env_returns_expected_non_terminating( + env, proxy, batch, request +): + env = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env = env.reset() + env.proxy = proxy + env.setup_proxy() + batch.set_env(env) + + rewards = [] + while not env.done: + parent = env.state + # Sample random action + _, action, valid = env.step_random() + # Add to batch + batch.add_to_batch([env], [action], [valid]) + if valid: + rewards.append(env.reward(do_non_terminating=True)) + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch) + + +@pytest.mark.repeat(N_REPETITIONS) +# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [("grid2d", "corners"), ("tetris6x4", "tetris_score_norm")], +) +def test__get_rewards_multiple_env_returns_expected_non_zero_non_terminating( + env, proxy, batch, request +): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env_ref = env_ref.reset() + env_ref.proxy = proxy + env_ref.setup_proxy() + env_ref.reward_func = "boltzmann" + + batch.set_env(env_ref) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + rewards = [] + proxy_values = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + 
valids_iter.append(valid) + rewards.append(env.reward(do_non_terminating=True)) + proxy_values.append(env.proxy(env.state2proxy(env.state))[0]) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch) + assert ~torch.any( + torch.isclose(rewards_batch, torch.zeros_like(rewards_batch)) + ), rewards_batch + + +@pytest.mark.repeat(N_REPETITIONS) +# @pytest.mark.skip(reason="skip while developping other tests") +@pytest.mark.parametrize( + "env, proxy", + [ + ("grid2d", "corners"), + ("tetris6x4", "tetris_score_norm"), + ("ctorus2d5l", "corners"), + ], +) +def test__get_rewards_parents_multiple_env_returns_expected_non_terminating( + env, proxy, batch, request +): + batch_size = BATCH_SIZE + env_ref = request.getfixturevalue(env) + proxy = request.getfixturevalue(proxy) + env_ref = env_ref.reset() + env_ref.proxy = proxy + env_ref.setup_proxy() + + batch.set_env(env_ref) + + # Make list of envs + envs = [] + for idx in range(batch_size): + env_aux = env_ref.copy().reset(idx) + envs.append(env_aux) + + rewards_parents = [] + rewards = [] + + # Iterate until envs is empty + while envs: + actions_iter = [] + valids_iter = [] + # Make step env by env (different to GFN Agent) to have full control + for env in envs: + parent = copy(env.state) + assert env.done is False + + # Sample random action + state, action, valid = env.step_random() + if valid: + # Add to iter lists + actions_iter.append(action) + valids_iter.append(valid) + rewards_parents.append( + env.reward(state=parent, done=False, do_non_terminating=True) + ) + rewards.append(env.reward(do_non_terminating=True)) + # Add all envs, actions and valids to batch + batch.add_to_batch(envs, actions_iter, valids_iter) + # Remove done envs + envs = [env for env in envs if not env.done] + + rewards_parents_batch = batch.get_rewards_parents() + rewards_parents = torch.stack(rewards_parents) + + rewards_batch = batch.get_rewards(do_non_terminating=True) + rewards = torch.stack(rewards) + + assert torch.equal( + rewards_parents_batch, + tfloat(rewards_parents, device=batch.device, float_type=batch.float), + ), (rewards_parents, rewards_parents_batch) + + assert torch.equal( + rewards_batch, + tfloat(rewards, device=batch.device, float_type=batch.float), + ), (rewards, rewards_batch)
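The two new losses reduce to the per-transition expressions sketched below. This is a minimal, standalone illustration operating on pre-computed tensors; the function names and flat-tensor signatures are illustrative only and not part of the repository API — in the PR, the equivalent computation lives in `GFlowNetAgent.detailedbalance_loss` and `GFlowNetAgent.forwardlooking_loss`, which pull the log-probabilities, state flows and rewards from a `Batch`.

```python
import torch


def detailed_balance_loss(logflows_parents, logflows_states, logprobs_f, logprobs_b,
                          log_rewards, done):
    # Detailed balance (eq. 11 of https://arxiv.org/abs/2201.13259):
    # (log F(s) + log PF(s'|s) - log F(s') - log PB(s|s'))^2 per transition,
    # with log F(s') replaced by the log-reward on terminating transitions.
    logflows_states = torch.where(done, log_rewards, logflows_states)
    loss_all = (logflows_parents + logprobs_f - logflows_states - logprobs_b).pow(2)
    # Overall mean plus the terminating / intermediate splits, as returned in the PR.
    return loss_all.mean(), loss_all[done].mean(), loss_all[~done].mean()


def forward_looking_loss(logflows_parents, logflows_states, logprobs_f, logprobs_b,
                         log_rewards_parents, log_rewards_states, done):
    # Forward-looking (https://arxiv.org/abs/2302.01687): the FL flow of terminating
    # states is fixed to 0, and the transition energy term is expressed through
    # log-rewards, as in the PR: log R(parent) - log R(state).
    logflows_states = torch.where(
        done, torch.zeros_like(logflows_states), logflows_states
    )
    energies_transitions = log_rewards_parents - log_rewards_states
    loss_all = (
        logflows_parents - logflows_states + logprobs_f - logprobs_b
        + energies_transitions
    ).pow(2)
    return loss_all.mean(), loss_all[done].mean(), loss_all[~done].mean()
```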
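Illustrative usage (the command lines mirror the sanity-check jobs added in `mila/dev/sanity_check_runs.yaml`; the exact env/proxy combinations are examples rather than requirements): the new losses are selected purely via the Hydra config groups, e.g. `python main.py env=grid env.length=10 gflownet=detailedbalance` or `python main.py env=tetris proxy=tetris gflownet=forwardlooking`. Both `config/gflownet/detailedbalance.yaml` and `config/gflownet/forwardlooking.yaml` pull in `state_flow: mlp`, which `main.py` instantiates and passes to the `GFlowNetAgent` as the additional state-flow model whose parameters are trained alongside the policies.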