From e65ad871c921bc960f7da1f14f137556295aedec Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:30:17 -0400 Subject: [PATCH 01/21] Env not part of module anymore. Preprocessor is passed instead. is_backward becomes a required flag to create module. DiscretePolicyEstimator doesn't have specific sampling parameters as part of init anymore. --- src/gfn/modules.py | 100 ++++++++++++++++++++++++--------------------- 1 file changed, 53 insertions(+), 47 deletions(-) diff --git a/src/gfn/modules.py b/src/gfn/modules.py index 55f6361f..1dde4fa4 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -5,7 +5,7 @@ from torch.distributions import Categorical, Distribution from torchtyping import TensorType as TT -from gfn.env import DiscreteEnv, Env +from gfn.preprocessors import IdentityPreprocessor, Preprocessor from gfn.states import DiscreteStates, States from gfn.utils.distributions import UnsqueezedCategorical @@ -29,12 +29,14 @@ class GFNModule(ABC, nn.Module): Otherwise, one can overwrite and use the to_probability_distribution() method to directly output a probability distribution. - The preprocessor is also encapsulated in the estimator via the environment. + The preprocessor is also encapsulated in the estimator. These function estimators implement the `__call__` method, which takes `States` objects as inputs and calls the module on the preprocessed states. Attributes: - env: the environment. + preprocessor: Preprocessor object that transforms raw States objects to tensors + that can be used as input to the module. Optional, defaults to + `IdentityPreprocessor`. module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. @@ -44,19 +46,31 @@ class GFNModule(ABC, nn.Module): been verified. """ - def __init__(self, env: Env, module: nn.Module) -> None: + def __init__( + self, + module: nn.Module, + preprocessor: Preprocessor | None = None, + is_backward: bool = False, + ) -> None: """Initalize the FunctionEstimator with an environment and a module. Args: - env: the environment. module: The module to use. If the module is a Tabular module (from `gfn.utils.modules`), then the environment preprocessor needs to be an `EnumPreprocessor`. + preprocessor: Preprocessor object. + is_backward: Flags estimators of probability distributions over parents. """ nn.Module.__init__(self) - self.env = env self.module = module - self.preprocessor = env.preprocessor # TODO: passed explicitly? + if preprocessor is None: + assert hasattr(module, "input_dim"), ( + "Module needs to have an attribute `input_dim` specifying the input " + + "dimension, in order to use the default IdentityPreprocessor." + ) + preprocessor = IdentityPreprocessor(module.input_dim) + self.preprocessor = preprocessor self._output_dim_is_checked = False + self.is_backward = is_backward def forward(self, states: States) -> TT["batch_shape", "output_dim", float]: out = self.module(self.preprocessor(states)) @@ -88,9 +102,12 @@ def to_probability_distribution( self, states: States, module_output: TT["batch_shape", "output_dim", float], + *args, ) -> Distribution: """Transform the output of the module into a probability distribution. + The kwargs modify a base distribution, for example to encourage exploration. + Not all modules must implement this method, but it is required to define a policy from a module's outputs. See `DiscretePolicyEstimator` for an example using a categorical distribution, but note this can be done for all continuous @@ -105,7 +122,7 @@ def expected_output_dim(self) -> int: class DiscretePolicyEstimator(GFNModule): - r"""Container for forward and backward policy estimators. + r"""Container for forward and backward policy estimators for discrete environments. $s \mapsto (P_F(s' \mid s))_{s' \in Children(s)}$. @@ -113,11 +130,6 @@ class DiscretePolicyEstimator(GFNModule): $s \mapsto (P_B(s' \mid s))_{s' \in Parents(s)}$. - Note that while this class resembles LogEdgeFlowProbabilityEstimator, they have - different semantic meaning. With LogEdgeFlowEstimator, the module output is the log - of the flow from the parent to the child, while with DiscretePFEstimator, the - module output is arbitrary. - Attributes: temperature: scalar to divide the logits by before softmax. sf_bias: scalar to subtract from the exit action logit before dividing by @@ -127,60 +139,54 @@ class DiscretePolicyEstimator(GFNModule): def __init__( self, - env: Env, module: nn.Module, - forward: bool, - greedy_eps: float = 0.0, - temperature: float = 1.0, - sf_bias: float = 0.0, - epsilon: float = 0.0, + n_actions: int, + preprocessor: Preprocessor | None, + is_backward: bool = False, ): """Initializes a estimator for P_F for discrete environments. Args: - forward: if True, then this is a forward policy, else backward policy. - greedy_eps: if > 0 , then we go off policy using greedy epsilon exploration. - temperature: scalar to divide the logits by before softmax. Does nothing - if greedy_eps is 0. - sf_bias: scalar to subtract from the exit action logit before dividing by - temperature. Does nothing if greedy_eps is 0. - epsilon: with probability epsilon, a random action is chosen. Does nothing - if greedy_eps is 0. + n_actions: Total number of actions in the Discrete Environment. + is_backward: if False, then this is a forward policy, else backward policy. """ - super().__init__(env, module) - assert greedy_eps >= 0 - self._forward = forward - self._greedy_eps = greedy_eps - self.temperature = temperature - self.sf_bias = sf_bias - self.epsilon = epsilon - - @property - def greedy_eps(self): - return self._greedy_eps + super().__init__(module, preprocessor, is_backward=is_backward) + self.n_actions = n_actions def expected_output_dim(self) -> int: - if self._forward: - return self.env.n_actions + if self.is_backward: + return self.n_actions - 1 else: - return self.env.n_actions - 1 + return self.n_actions def to_probability_distribution( self, states: DiscreteStates, module_output: TT["batch_shape", "output_dim", float], + temperature: float = 1.0, + sf_bias: float = 0.0, + epsilon: float = 0.0, ) -> Categorical: - """Returns a probability distribution given a batch of states and module output.""" - masks = states.forward_masks if self._forward else states.backward_masks + """Returns a probability distribution given a batch of states and module output. + + Args: + temperature: scalar to divide the logits by before softmax. Does nothing + if set to 1.0 (default), in which case it's on policy. + sf_bias: scalar to subtract from the exit action logit before dividing by + temperature. Does nothing if set to 0.0 (default), in which case it's + on policy. + epsilon: with probability epsilon, a random action is chosen. Does nothing + if set to 0.0 (default), in which case it's on policy.""" + masks = states.backward_masks if self.is_backward else states.forward_masks logits = module_output logits[~masks] = -float("inf") # Forward policy supports exploration in many implementations. - if self._greedy_eps: - logits[:, -1] -= self.sf_bias - probs = torch.softmax(logits / self.temperature, dim=-1) + if temperature != 1.0 or sf_bias != 0.0 or epsilon != 0.0: + logits[:, -1] -= sf_bias + probs = torch.softmax(logits / temperature, dim=-1) uniform_dist_probs = masks.float() / masks.sum(dim=-1, keepdim=True) - probs = (1 - self.epsilon) * probs + self.epsilon * uniform_dist_probs + probs = (1 - epsilon) * probs + epsilon * uniform_dist_probs return UnsqueezedCategorical(probs=probs) From 2caa0d43311176eaeec1af94ab13424ff6e85863 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:31:03 -0400 Subject: [PATCH 02/21] Adapt to changes of Module API --- src/gfn/gym/helpers/box_utils.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 0783b701..2dfd83e8 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -359,8 +359,8 @@ class DistributionWrapper(Distribution): def __init__( self, states: States, - env: Box, delta: float, + epsilon: float, mixture_logits, alpha_r, beta_r, @@ -370,7 +370,6 @@ def __init__( n_components, n_components_s0, ): - self.env = env self.idx_is_initial = torch.where(torch.all(states.tensor == 0, 1))[0] self.idx_not_initial = torch.where(torch.any(states.tensor != 0, 1))[0] self._output_shape = states.tensor.shape @@ -387,13 +386,13 @@ def __init__( self.quarter_circ = None if len(self.idx_not_initial) > 0: self.quarter_circ = QuarterCircleWithExit( - delta=self.env.delta, + delta=delta, centers=states[self.idx_not_initial], # Remove initial states. exit_probability=exit_probability[self.idx_not_initial], mixture_logits=mixture_logits[self.idx_not_initial, :n_components], alpha=alpha_theta[self.idx_not_initial, :n_components], beta=beta_theta[self.idx_not_initial, :n_components], - epsilon=self.env.epsilon, + epsilon=epsilon, ) # no sample_shape req as it is stored in centers. def sample(self, sample_shape=()): @@ -472,6 +471,7 @@ def __init__( self.n_components = n_components input_dim = 2 + self.input_dim = input_dim output_dim = 1 + 3 * self.n_components @@ -571,6 +571,7 @@ def __init__( **kwargs: passed to the NeuralNet class. """ input_dim = 2 + self.input_dim = input_dim output_dim = 3 * n_components super().__init__( @@ -619,6 +620,8 @@ class BoxPBUniform(torch.nn.Module): uniform distribution over parents in the south-western part of circle. """ + input_dim = 2 + def forward( self, preprocessed_states: TT["batch_shape", 2, float] ) -> TT["batch_shape", 3]: @@ -680,14 +683,15 @@ def __init__( min_concentration: float = 0.1, max_concentration: float = 2.0, ): - super().__init__(env, module) + super().__init__(module) self._n_comp_max = max(n_components_s0, n_components) self.n_components_s0 = n_components_s0 self.n_components = n_components self.min_concentration = min_concentration self.max_concentration = max_concentration - self.env = env + self.delta = env.delta + self.epsilon = env.epsilon def expected_output_dim(self) -> int: return 1 + 5 * self._n_comp_max @@ -736,8 +740,8 @@ def _normalize(x): return DistributionWrapper( states, - self.env, - self.env.delta, + self.delta, + self.epsilon, mixture_logits, alpha_r, beta_r, @@ -760,13 +764,15 @@ def __init__( min_concentration: float = 0.1, max_concentration: float = 2.0, ): - super().__init__(env, module) + super().__init__(module, is_backward=True) self.module = module self.n_components = n_components self.min_concentration = min_concentration self.max_concentration = max_concentration + self.delta = env.delta + def expected_output_dim(self) -> int: return 3 * self.n_components @@ -789,7 +795,7 @@ def _normalize(x): alpha = _normalize(alpha) beta = _normalize(beta) return QuarterCircle( - delta=self.env.delta, + delta=self.delta, northeastern=False, centers=states, mixture_logits=mixture_logits, From 3d4eac224f95a882dc0a6c34e967a6a412dc7b82 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:31:31 -0400 Subject: [PATCH 03/21] env passed explicitly in many methods of GFlowNets --- src/gfn/gflownet/base.py | 17 ++++++++---- src/gfn/gflownet/detailed_balance.py | 20 ++++++++------ src/gfn/gflownet/flow_matching.py | 32 ++++++++++++---------- src/gfn/gflownet/sub_trajectory_balance.py | 19 +++++++------ src/gfn/gflownet/trajectory_balance.py | 24 ++++++++++++---- 5 files changed, 70 insertions(+), 42 deletions(-) diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index e1291a8f..b5727486 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -6,6 +6,7 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.modules import GFNModule from gfn.samplers import Sampler from gfn.states import States @@ -18,30 +19,36 @@ class GFlowNet(nn.Module): """ @abstractmethod - def sample_trajectories(self, n_samples: int) -> Trajectories: + def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories: """Sample a specific number of complete trajectories. Args: + env: the environment to sample trajectories from. n_samples: number of trajectories to be sampled. Returns: Trajectories: sampled trajectories object. """ - def sample_terminating_states(self, n_samples: int) -> States: + def sample_terminating_states(self, env: Env, n_samples: int) -> States: """Rolls out the parametrization's policy and returns the terminating states. Args: + env: the environment to sample terminating states from. n_samples: number of terminating states to be sampled. Returns: States: sampled terminating states object. """ - trajectories = self.sample_trajectories(n_samples) + trajectories = self.sample_trajectories(env, n_samples) return trajectories.last_states @abstractmethod def to_training_samples(self, trajectories: Trajectories): """Converts trajectories to training samples. The type depends on the GFlowNet.""" + @abstractmethod + def loss(self, env: Env, training_objects): + """Computes the loss given the training objects.""" + class PFBasedGFlowNet(GFlowNet): r"""Base class for gflownets that explicitly uses $P_F$. @@ -57,9 +64,9 @@ def __init__(self, pf: GFNModule, pb: GFNModule, on_policy: bool = False): self.pb = pb self.on_policy = on_policy - def sample_trajectories(self, n_samples: int = 1000) -> Trajectories: + def sample_trajectories(self, env: Env, n_samples: int) -> Trajectories: sampler = Sampler(estimator=self.pf) - trajectories = sampler.sample_trajectories(n_trajectories=n_samples) + trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples) return trajectories diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index 02125c77..2310abaa 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -4,8 +4,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories, Transitions +from gfn.env import Env from gfn.gflownet.base import PFBasedGFlowNet -from gfn.modules import ScalarEstimator +from gfn.modules import GFNModule, ScalarEstimator class DBGFlowNet(PFBasedGFlowNet): @@ -27,17 +28,18 @@ class DBGFlowNet(PFBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, logF: ScalarEstimator, + on_policy: bool = False, forward_looking: bool = False, - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logF = logF self.forward_looking = forward_looking - self.env = self.logF.env # TODO We don't want to store env in here... def get_scores( - self, transitions: Transitions + self, env: Env, transitions: Transitions ) -> Tuple[ TT["n_transitions", float], TT["n_transitions", float], @@ -72,7 +74,7 @@ def get_scores( valid_log_F_s = self.logF(states).squeeze(-1) if self.forward_looking: - log_rewards = self.env.log_reward(states) # RM unsqueeze(-1) + log_rewards = env.log_reward(states) # RM unsqueeze(-1) valid_log_F_s = valid_log_F_s + log_rewards preds = valid_log_pf_actions + valid_log_F_s @@ -110,12 +112,12 @@ def get_scores( return (valid_log_pf_actions, log_pb_actions, scores) - def loss(self, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: """Detailed balance loss. The detailed balance loss is described in section 3.2 of [GFlowNet Foundations](https://arxiv.org/abs/2111.09266).""" - _, _, scores = self.get_scores(transitions) + _, _, scores = self.get_scores(env, transitions) loss = torch.mean(scores**2) if torch.isnan(loss): @@ -182,7 +184,7 @@ def get_scores(self, transitions: Transitions) -> TT["n_trajectories", torch.flo return scores - def loss(self, transitions: Transitions) -> TT[0, float]: + def loss(self, env: Env, transitions: Transitions) -> TT[0, float]: """Calculates the modified detailed balance loss.""" scores = self.get_scores(transitions) return torch.mean(scores**2) diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index 998859d9..f9872732 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -4,6 +4,7 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.gflownet.base import GFlowNet from gfn.modules import DiscretePolicyEstimator from gfn.samplers import Sampler @@ -28,23 +29,21 @@ class FMGFlowNet(GFlowNet): def __init__(self, logF: DiscretePolicyEstimator, alpha: float = 1.0): super().__init__() - assert not logF.greedy_eps self.logF = logF self.alpha = alpha - self.env = self.logF.env - if not self.env.is_discrete: + + def sample_trajectories(self, env: Env, n_samples: int = 1000) -> Trajectories: + if not env.is_discrete: raise NotImplementedError( "Flow Matching GFlowNet only supports discrete environments for now." ) - - def sample_trajectories(self, n_samples: int = 1000) -> Trajectories: sampler = Sampler(estimator=self.logF) - trajectories = sampler.sample_trajectories(n_trajectories=n_samples) + trajectories = sampler.sample_trajectories(env, n_trajectories=n_samples) return trajectories def flow_matching_loss( - self, states: DiscreteStates + self, env: Env, states: DiscreteStates ) -> TT["n_trajectories", torch.float]: """Computes the FM for the provided states. @@ -67,7 +66,7 @@ def flow_matching_loss( states.forward_masks, -float("inf"), dtype=torch.float ) - for action_idx in range(self.env.n_actions - 1): + for action_idx in range(env.n_actions - 1): valid_backward_mask = states.backward_masks[:, action_idx] valid_forward_mask = states.forward_masks[:, action_idx] valid_backward_states = states[valid_backward_mask] @@ -76,9 +75,9 @@ def flow_matching_loss( backward_actions = torch.full_like( valid_backward_states.backward_masks[:, 0], action_idx, dtype=torch.long ).unsqueeze(-1) - backward_actions = self.env.Actions(backward_actions) + backward_actions = env.Actions(backward_actions) - valid_backward_states_parents = self.env.backward_step( + valid_backward_states_parents = env.backward_step( valid_backward_states, backward_actions ) @@ -101,8 +100,11 @@ def flow_matching_loss( return (log_incoming_flows - log_outgoing_flows).pow(2).mean() - def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, float]: + def reward_matching_loss( + self, env: Env, terminating_states: DiscreteStates + ) -> TT[0, float]: """Calculates the reward matching loss from the terminating states.""" + del env # Unused assert terminating_states.log_rewards is not None log_edge_flows = self.logF(terminating_states) @@ -111,7 +113,9 @@ def reward_matching_loss(self, terminating_states: DiscreteStates) -> TT[0, floa log_rewards = terminating_states.log_rewards return (terminating_log_edge_flows - log_rewards).pow(2).mean() - def loss(self, states_tuple: Tuple[DiscreteStates, DiscreteStates]) -> TT[0, float]: + def loss( + self, env: Env, states_tuple: Tuple[DiscreteStates, DiscreteStates] + ) -> TT[0, float]: """Given a batch of non-terminal and terminal states, compute a loss. Unlike the GFlowNets Foundations paper, we allow more flexibility by passing a @@ -119,8 +123,8 @@ def loss(self, states_tuple: Tuple[DiscreteStates, DiscreteStates]) -> TT[0, flo (i.e. non-terminal states), and the second one being the terminal states of the trajectories.""" intermediary_states, terminating_states = states_tuple - fm_loss = self.flow_matching_loss(intermediary_states) - rm_loss = self.reward_matching_loss(terminating_states) + fm_loss = self.flow_matching_loss(env, intermediary_states) + rm_loss = self.reward_matching_loss(env, terminating_states) return fm_loss + self.alpha * rm_loss def to_training_samples( diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index bd83e707..175be074 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -4,8 +4,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories -from gfn.gflownet.base import PFBasedGFlowNet, TrajectoryBasedGFlowNet -from gfn.modules import ScalarEstimator +from gfn.env import Env +from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.modules import GFNModule, ScalarEstimator class SubTBGFlowNet(TrajectoryBasedGFlowNet): @@ -43,7 +44,10 @@ class SubTBGFlowNet(TrajectoryBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, logF: ScalarEstimator, + on_policy: bool = False, weighting: Literal[ "DB", "ModifiedDB", @@ -56,9 +60,8 @@ def __init__( lamda: float = 0.9, log_reward_clip_min: float = -12, # roughly log(1e-5) forward_looking: bool = False, - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logF = logF self.weighting = weighting self.lamda = lamda @@ -89,7 +92,7 @@ def cumulative_logprobs( ) def get_scores( - self, trajectories: Trajectories + self, env: Env, trajectories: Trajectories ) -> Tuple[List[TT[0, float]], List[TT[0, float]]]: """Scores all submitted trajectories. @@ -123,7 +126,7 @@ def get_scores( log_F = self.logF(valid_states).squeeze(-1) if self.forward_looking: - log_rewards = self.logF.env.log_reward(states).unsqueeze(-1) + log_rewards = env.log_reward(states).unsqueeze(-1) log_F = log_F + log_rewards log_state_flows[mask[:-1]] = log_F @@ -188,9 +191,9 @@ def get_scores( flattening_masks, ) - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: # Get all scores and masks from the trajectories. - scores, flattening_masks = self.get_scores(trajectories) + scores, flattening_masks = self.get_scores(env, trajectories) flattening_mask = torch.cat(flattening_masks) all_scores = torch.cat(scores, 0) diff --git a/src/gfn/gflownet/trajectory_balance.py b/src/gfn/gflownet/trajectory_balance.py index b638789d..bceac033 100644 --- a/src/gfn/gflownet/trajectory_balance.py +++ b/src/gfn/gflownet/trajectory_balance.py @@ -8,7 +8,9 @@ from torchtyping import TensorType as TT from gfn.containers import Trajectories +from gfn.env import Env from gfn.gflownet.base import TrajectoryBasedGFlowNet +from gfn.modules import GFNModule class TBGFlowNet(TrajectoryBasedGFlowNet): @@ -28,16 +30,18 @@ class TBGFlowNet(TrajectoryBasedGFlowNet): def __init__( self, + pf: GFNModule, + pb: GFNModule, + on_policy: bool = False, init_logZ: float = 0.0, log_reward_clip_min: float = -12, # roughly log(1e-5) - **kwargs, ): - super().__init__(**kwargs) + super().__init__(pf, pb, on_policy=on_policy) self.logZ = nn.Parameter(torch.tensor(init_logZ)) self.log_reward_clip_min = log_reward_clip_min - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: """Trajectory balance loss. The trajectory balance loss is described in 2.3 of @@ -46,6 +50,7 @@ def loss(self, trajectories: Trajectories) -> TT[0, float]: Raises: ValueError: if the loss is NaN. """ + del env # unused _, _, scores = self.get_trajectories_scores(trajectories) loss = (scores + self.logZ).pow(2).mean() if torch.isnan(loss): @@ -64,17 +69,24 @@ class LogPartitionVarianceGFlowNet(TrajectoryBasedGFlowNet): ValueError: if the loss is NaN. """ - def __init__(self, log_reward_clip_min: float = -12, **kwargs): - super().__init__(**kwargs) + def __init__( + self, + pf: GFNModule, + pb: GFNModule, + on_policy: bool = False, + log_reward_clip_min: float = -12, + ): + super().__init__(pf, pb, on_policy=on_policy) self.log_reward_clip_min = log_reward_clip_min # -12 is roughly log(1e-5) - def loss(self, trajectories: Trajectories) -> TT[0, float]: + def loss(self, env: Env, trajectories: Trajectories) -> TT[0, float]: """Log Partition Variance loss. This method is described in section 3.2 of [ROBUST SCHEDULING WITH GFLOWNETS](https://arxiv.org/abs/2302.05446)) """ + del env # unused _, _, scores = self.get_trajectories_scores(trajectories) loss = (scores - scores.mean()).pow(2).mean() if torch.isnan(loss): From 7eab2dc005e8c1104b986a8fac48e20aef5f5703 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:31:50 -0400 Subject: [PATCH 04/21] env is not part of the sampler anymore --- src/gfn/samplers.py | 65 ++++++++++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 8868528c..83d98221 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -5,11 +5,11 @@ from gfn.actions import Actions from gfn.containers import Trajectories +from gfn.env import Env from gfn.modules import GFNModule from gfn.states import States -# TODO: Environment should not live inside the estimator and here... needs refactor. class Sampler: """`Sampler is a container for a PolicyEstimator. @@ -18,23 +18,26 @@ class Sampler: Attributes: estimator: the submitted PolicyEstimator. - env: the Environment instance inside the PolicyEstimator. - is_backward: if True, samples trajectories of actions backward (a distribution - over parents). If True, the estimator must be a ProbabilityDistribution - over parents. + probability_distribution_kwargs: keyword arguments to be passed to the `to_probability_distribution` + method of the estimator. For example, for DiscretePolicyEstimators, the kwargs can contain + the `temperature` parameter, `epsilon`, and `sf_bias`. """ - def __init__(self, estimator: GFNModule, is_backward: bool = False) -> None: + def __init__( + self, + estimator: GFNModule, + **probability_distribution_kwargs: Optional[dict], + ) -> None: self.estimator = estimator - self.env = estimator.env - self.is_backward = is_backward # TODO: take directly from estimator. + self.probability_distribution_kwargs = probability_distribution_kwargs def sample_actions( - self, states: States + self, env: Env, states: States ) -> Tuple[Actions, TT["batch_shape", torch.float]]: """Samples actions from the given states. Args: + env: The environment to sample actions from. states (States): A batch of states. Returns: @@ -45,7 +48,9 @@ def sample_actions( states. """ module_output = self.estimator(states) - dist = self.estimator.to_probability_distribution(states, module_output) + dist = self.estimator.to_probability_distribution( + states, module_output, **self.probability_distribution_kwargs + ) with torch.no_grad(): actions = dist.sample() @@ -53,16 +58,18 @@ def sample_actions( if torch.any(torch.isinf(log_probs)): raise RuntimeError("Log probabilities are inf. This should not happen.") - return self.env.Actions(actions), log_probs + return env.Actions(actions), log_probs def sample_trajectories( self, + env: Env, states: Optional[States] = None, n_trajectories: Optional[int] = None, ) -> Trajectories: """Sample trajectories sequentially. Args: + env: The environment to sample trajectories from. states: If given, trajectories would start from such states. Otherwise, trajectories are sampled from $s_o$ and n_trajectories must be provided. n_trajectories: If given, a batch of n_trajectories will be sampled all @@ -78,7 +85,7 @@ def sample_trajectories( assert ( n_trajectories is not None ), "Either states or n_trajectories should be specified" - states = self.env.reset(batch_shape=(n_trajectories,)) + states = env.reset(batch_shape=(n_trajectories,)) else: assert ( len(states.batch_shape) == 1 @@ -87,7 +94,11 @@ def sample_trajectories( device = states.tensor.device - dones = states.is_initial_state if self.is_backward else states.is_sink_state + dones = ( + states.is_initial_state + if self.estimator.is_backward + else states.is_sink_state + ) trajectories_states: List[TT["n_trajectories", "state_shape", torch.float]] = [ states.tensor @@ -104,37 +115,37 @@ def sample_trajectories( step = 0 while not all(dones): - actions = self.env.Actions.make_dummy_actions(batch_shape=(n_trajectories,)) + actions = env.Actions.make_dummy_actions(batch_shape=(n_trajectories,)) log_probs = torch.full( (n_trajectories,), fill_value=0, dtype=torch.float, device=device ) - valid_actions, actions_log_probs = self.sample_actions(states[~dones]) + valid_actions, actions_log_probs = self.sample_actions(env, states[~dones]) actions[~dones] = valid_actions log_probs[~dones] = actions_log_probs trajectories_actions += [actions] trajectories_logprobs += [log_probs] - if self.is_backward: - new_states = self.env.backward_step(states, actions) + if self.estimator.is_backward: + new_states = env.backward_step(states, actions) else: - new_states = self.env.step(states, actions) + new_states = env.step(states, actions) sink_states_mask = new_states.is_sink_state step += 1 new_dones = ( - new_states.is_initial_state if self.is_backward else sink_states_mask + new_states.is_initial_state + if self.estimator.is_backward + else sink_states_mask ) & ~dones trajectories_dones[new_dones & ~dones] = step try: - trajectories_log_rewards[new_dones & ~dones] = self.env.log_reward( + trajectories_log_rewards[new_dones & ~dones] = env.log_reward( states[new_dones & ~dones] ) except NotImplementedError: - # print(states[new_dones & ~dones]) - # print(torch.log(self.env.reward(states[new_dones & ~dones]))) trajectories_log_rewards[new_dones & ~dones] = torch.log( - self.env.reward(states[new_dones & ~dones]) + env.reward(states[new_dones & ~dones]) ) states = new_states dones = dones | new_dones @@ -142,16 +153,16 @@ def sample_trajectories( trajectories_states += [states.tensor] trajectories_states = torch.stack(trajectories_states, dim=0) - trajectories_states = self.env.States(tensor=trajectories_states) - trajectories_actions = self.env.Actions.stack(trajectories_actions) + trajectories_states = env.States(tensor=trajectories_states) + trajectories_actions = env.Actions.stack(trajectories_actions) trajectories_logprobs = torch.stack(trajectories_logprobs, dim=0) trajectories = Trajectories( - env=self.env, + env=env, states=trajectories_states, actions=trajectories_actions, when_is_done=trajectories_dones, - is_backward=self.is_backward, + is_backward=self.estimator.is_backward, log_rewards=trajectories_log_rewards, log_probs=trajectories_logprobs, ) From 8f251fb689f9911420f7cdaffafd0d21259da7fd Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:32:04 -0400 Subject: [PATCH 05/21] adapt scripts to api changes --- tutorials/examples/train_box.py | 9 +++++---- tutorials/examples/train_discreteebm.py | 10 ++++++--- tutorials/examples/train_hypergrid.py | 27 +++++++++++++++++++------ 3 files changed, 33 insertions(+), 13 deletions(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 30d6d854..4a3ce508 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -297,7 +297,7 @@ def estimate_jsd(kde1, kde2): torso=None, # We do not tie the parameters of the flow function to PF logZ_value=logZ, ) - logF_estimator = ScalarEstimator(env=env, module=module) + logF_estimator = ScalarEstimator(module=module, preprocessor=env.preprocessor) if args.loss == "DB": gflownet = DBGFlowNet( @@ -378,16 +378,17 @@ def estimate_jsd(kde1, kde2): if iteration % 1000 == 0: print(f"current optimizer LR: {optimizer.param_groups[0]['lr']}") - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() - loss = gflownet.loss(training_samples) + loss = gflownet.loss(env, training_samples) loss.backward() for p in gflownet.parameters(): - p.grad.data.clamp_(-10, 10).nan_to_num_(0.0) + if p.grad is not None: + p.grad.data.clamp_(-10, 10).nan_to_num_(0.0) optimizer.step() scheduler.step() diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 84191f2d..0dcd02a8 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -128,7 +128,11 @@ hidden_dim=args.hidden_dim, n_hidden_layers=args.n_hidden, ) - estimator = DiscretePolicyEstimator(env=env, module=module, forward=True) + estimator = DiscretePolicyEstimator( + module=module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) gflownet = FMGFlowNet(estimator) # 3. Create the optimizer @@ -141,11 +145,11 @@ states_visited = 0 n_iterations = args.n_trajectories // args.batch_size for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) optimizer.zero_grad() - loss = gflownet.loss(training_samples) + loss = gflownet.loss(env, training_samples) loss.backward() optimizer.step() diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 5a613faf..242296b5 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -181,7 +181,11 @@ hidden_dim=args.hidden_dim, n_hidden_layers=args.n_hidden, ) - estimator = DiscretePolicyEstimator(env=env, module=module, forward=True) + estimator = DiscretePolicyEstimator( + module=module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) gflownet = FMGFlowNet(estimator) else: pb_module = None @@ -215,8 +219,17 @@ pb_module is not None ), f"pb_module is None. Command-line arguments: {args}" - pf_estimator = DiscretePolicyEstimator(env=env, module=pf_module, forward=True) - pb_estimator = DiscretePolicyEstimator(env=env, module=pb_module, forward=False) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=env.preprocessor, + ) if args.loss == "ModifiedDB": gflownet = ModifiedDBGFlowNet( @@ -245,7 +258,9 @@ torso=pf_module.torso if args.tied else None, ) - logF_estimator = ScalarEstimator(env=env, module=module) + logF_estimator = ScalarEstimator( + module=module, preprocessor=env.preprocessor + ) if args.loss == "DB": gflownet = DBGFlowNet( pf=pf_estimator, @@ -321,7 +336,7 @@ states_visited = 0 n_iterations = args.n_trajectories // args.batch_size for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(n_samples=args.batch_size) + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) if replay_buffer is not None: with torch.no_grad(): @@ -331,7 +346,7 @@ training_objects = training_samples optimizer.zero_grad() - loss = gflownet.loss(training_objects) + loss = gflownet.loss(env, training_objects) loss.backward() optimizer.step() From 863bcfef8d39dfb521f65e4f0c43ed1bb868312f Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:32:11 -0400 Subject: [PATCH 06/21] adapt tests to api changes --- testing/test_parametrizations_and_losses.py | 32 ++++++++++++--------- testing/test_samplers_and_trajectories.py | 22 ++++++++++---- 2 files changed, 35 insertions(+), 19 deletions(-) diff --git a/testing/test_parametrizations_and_losses.py b/testing/test_parametrizations_and_losses.py index 773a7583..ac7ffb5d 100644 --- a/testing/test_parametrizations_and_losses.py +++ b/testing/test_parametrizations_and_losses.py @@ -18,7 +18,7 @@ BoxPFEstimator, BoxPFNeuralNet, ) -from gfn.modules import DiscretePolicyEstimator, GFNModule, ScalarEstimator +from gfn.modules import DiscretePolicyEstimator, ScalarEstimator from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular @@ -51,15 +51,15 @@ def test_FM(env_name: int, ndim: int, module_name: str): raise ValueError("Unknown module name") log_F_edge = DiscretePolicyEstimator( - env=env, module=module, - forward=True, + n_actions=env.n_actions, + preprocessor=env.preprocessor, ) gflownet = FMGFlowNet(log_F_edge) # forward looking by default. - trajectories = gflownet.sample_trajectories(n_samples=10) + trajectories = gflownet.sample_trajectories(env, n_samples=10) states_tuple = trajectories.to_non_initial_intermediary_and_terminating_states() - loss = gflownet.loss(states_tuple) + loss = gflownet.loss(env, states_tuple) assert loss >= 0 @@ -174,10 +174,14 @@ def PFBasedGFlowNet_with_return( n_components=ndim + 1 if module_name != "Uniform" else 1, ) else: - pf = DiscretePolicyEstimator(env, pf_module, forward=True) - pb = DiscretePolicyEstimator(env, pb_module, forward=False) + pf = DiscretePolicyEstimator( + pf_module, env.n_actions, preprocessor=env.preprocessor + ) + pb = DiscretePolicyEstimator( + pb_module, env.n_actions, preprocessor=env.preprocessor, is_backward=True + ) - logF = ScalarEstimator(env, module=logF_module) + logF = ScalarEstimator(module=logF_module, preprocessor=env.preprocessor) if gflownet_name == "DB": gflownet = DBGFlowNet( @@ -202,10 +206,10 @@ def PFBasedGFlowNet_with_return( else: raise ValueError(f"Unknown gflownet {gflownet_name}") - trajectories = gflownet.sample_trajectories(10) + trajectories = gflownet.sample_trajectories(env, 10) training_objects = gflownet.to_training_samples(trajectories) - _ = gflownet.loss(training_objects) + _ = gflownet.loss(env, training_objects) if gflownet_name == "TB": assert torch.all( @@ -299,9 +303,11 @@ def test_subTB_vs_TB( zero_logF=True, ) - trajectories = gflownet.sample_trajectories(10) - subtb_loss = gflownet.loss(trajectories) + trajectories = gflownet.sample_trajectories(env, 10) + subtb_loss = gflownet.loss(env, trajectories) if weighting == "TB": - tb_loss = TBGFlowNet(pf=pf, pb=pb).loss(trajectories) # LogZ is default 0.0. + tb_loss = TBGFlowNet(pf=pf, pb=pb).loss( + env, trajectories + ) # LogZ is default 0.0. assert (tb_loss - subtb_loss).abs() < 1e-4 diff --git a/testing/test_samplers_and_trajectories.py b/testing/test_samplers_and_trajectories.py index cfa65197..a871d940 100644 --- a/testing/test_samplers_and_trajectories.py +++ b/testing/test_samplers_and_trajectories.py @@ -65,16 +65,26 @@ def trajectory_sampling_with_return( pb_module = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1 ) - pf_estimator = DiscretePolicyEstimator(env=env, module=pf_module, forward=True) - pb_estimator = DiscretePolicyEstimator(env=env, module=pb_module, forward=False) + pf_estimator = DiscretePolicyEstimator( + module=pf_module, + n_actions=env.n_actions, + is_backward=False, + preprocessor=env.preprocessor, + ) + pb_estimator = DiscretePolicyEstimator( + module=pb_module, + n_actions=env.n_actions, + is_backward=True, + preprocessor=env.preprocessor, + ) sampler = Sampler(estimator=pf_estimator) - trajectories = sampler.sample_trajectories(n_trajectories=5) - trajectories = sampler.sample_trajectories(n_trajectories=10) + trajectories = sampler.sample_trajectories(env, n_trajectories=5) + trajectories = sampler.sample_trajectories(env, n_trajectories=10) states = env.reset(batch_shape=5, random=True) - bw_sampler = Sampler(estimator=pb_estimator, is_backward=True) - bw_trajectories = bw_sampler.sample_trajectories(states) + bw_sampler = Sampler(estimator=pb_estimator) + bw_trajectories = bw_sampler.sample_trajectories(env, states) return trajectories, bw_trajectories, pf_estimator, pb_estimator From 4fd4a3e32056640cac2269c7c5b68f1452b41fc8 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 15:32:17 -0400 Subject: [PATCH 07/21] adapt readme to api changes --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 627f933f..44f3f0b9 100644 --- a/README.md +++ b/README.md @@ -81,12 +81,12 @@ if __name__ == "__main__": torso=module_PF.torso ) - pf_estimator = DiscretePolicyEstimator(env, module_PF, forward=True) - pb_estimator = DiscretePolicyEstimator(env, module_PB, forward=False) + pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor) + pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor) gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) - sampler = Sampler(estimator=pf_estimator)) + sampler = Sampler(estimator=pf_estimator) # Policy parameters have their own LR. non_logz_params = [v for k, v in dict(gfn.named_parameters()).items() if k != "logZ"] @@ -97,9 +97,9 @@ if __name__ == "__main__": optimizer.add_param_group({"params": logz_params, "lr": 1e-2}) for i in (pbar := tqdm(range(1000))): - trajectories = sampler.sample_trajectories(n_trajectories=16) + trajectories = sampler.sample_trajectories(env=env, n_trajectories=16) optimizer.zero_grad() - loss = gfn.loss(trajectories) + loss = gfn.loss(env, trajectories) loss.backward() optimizer.step() if i % 25 == 0: @@ -171,8 +171,8 @@ In most cases, one needs to sample complete trajectories. From a batch of trajec ### Modules -Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. They take the environment `env` as an input at initialization. -- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `backward=False`, the required output dimension is `n = env.n_actions`, and when `backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. Additionally, they include exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. Naturally, before defining the Categorical distributions, forbidden actions (that are encoded in the `DiscreteStates`' masks attributes), are given 0 probability by setting the corresponding logit to $-\infty$. +Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. +- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. Additionally, they include exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. Naturally, before defining the Categorical distributions, forbidden actions (that are encoded in the `DiscreteStates`' masks attributes), are given 0 probability by setting the corresponding logit to $-\infty$. - `ScalarModule` is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$. For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States`) object, should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user defined environment need this much level of details. From a7b056e134b2e8e04032fb3f7177c1f9f5b0099d Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 16:15:52 -0400 Subject: [PATCH 08/21] add forgotten env --- tutorials/examples/train_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 4a3ce508..2e0b21b2 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -409,7 +409,7 @@ def estimate_jsd(kde1, kde2): if iteration % args.validation_interval == 0: validation_samples = gflownet.sample_terminating_states( - args.validation_samples + env, args.validation_samples ) kde = KernelDensity(kernel="exponential", bandwidth=0.1).fit( validation_samples.tensor.detach().cpu().numpy() From 815e95c237b2a924b4129fc0a188c1a9bc8024cc Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 16:31:41 -0400 Subject: [PATCH 09/21] add tests for the scripts --- src/gfn/gym/helpers/test_box_utils.py | 7 +- tutorials/examples/test_scripts.py | 112 +++++++++ tutorials/examples/train_box.py | 306 ++++++++++++------------ tutorials/examples/train_discreteebm.py | 142 +++++------ tutorials/examples/train_hypergrid.py | 240 ++++++++++--------- 5 files changed, 471 insertions(+), 336 deletions(-) create mode 100644 tutorials/examples/test_scripts.py diff --git a/src/gfn/gym/helpers/test_box_utils.py b/src/gfn/gym/helpers/test_box_utils.py index bc011d92..aefd12e2 100644 --- a/src/gfn/gym/helpers/test_box_utils.py +++ b/src/gfn/gym/helpers/test_box_utils.py @@ -1,4 +1,5 @@ import torch +import pytest from gfn.gym import Box from gfn.gym.helpers.box_utils import ( @@ -13,14 +14,14 @@ ) -def test_mixed_distributions(): +@pytest.mark.parametrize("n_components", [5, 6]) +@pytest.mark.parametrize("n_components_s0", [5, 6]) +def test_mixed_distributions(n_components: int, n_components_s0: int): """Ensure DistributionWrapper functions correctly.""" delta = 0.1 hidden_dim = 10 n_hidden_layers = 2 - n_components = 5 - n_components_s0 = 6 environment = Box( delta=delta, diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py new file mode 100644 index 00000000..cf16e319 --- /dev/null +++ b/tutorials/examples/test_scripts.py @@ -0,0 +1,112 @@ +from dataclasses import dataclass +import pytest + +from .train_box import main as train_box_main +from .train_discreteebm import main as train_discreteebm_main +from .train_hypergrid import main as train_hypergrid_main + + +@dataclass +class CommonArgs: + no_cuda: bool = True + seed: int = 0 + batch_size: int = 16 + replay_buffer_size: int = 0 + loss: str = "TB" + subTB_weighting: str = "geometric_within" + subTB_lambda: float = 0.9 + tabular: bool = False + uniform_pb: bool = False + tied: bool = False + hidden_dim: int = 256 + n_hidden: int = 2 + lr: float = 1e-3 + lr_Z: float = 1e-1 + n_trajectories: int = 32000 + validation_interval: int = 100 + validation_samples: int = 200000 + wandb_project: str = "" + + +@dataclass +class DiscreteEBMArgs(CommonArgs): + ndim: int = 4 + alpha: float = 1.0 + + +@dataclass +class HypergridArgs(CommonArgs): + ndim: int = 2 + height: int = 8 + R0: float = 0.1 + R1: float = 0.5 + R2: float = 2.0 + + +@dataclass +class BoxArgs(CommonArgs): + delta: float = 0.25 + min_concentration: float = 0.1 + max_concentration: float = 5.1 + n_components: int = 2 + n_components_s0: int = 2 + gamma_scheduler: float = 0.5 + scheduler_milestone: int = 2500 + lr_F: float = 1e-2 + + +@pytest.mark.parametrize("ndim", [2, 4]) +@pytest.mark.parametrize("height", [8, 16]) +def test_hypergrid(ndim: int, height: int): + n_trajectories = 32000 if ndim == 2 else 16000 + args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) + final_l1_dist = train_hypergrid_main(args) + if ndim == 2 and height == 8: + assert final_l1_dist < 1e-3 + elif ndim == 2 and height == 16: + assert final_l1_dist < 6e-4 + elif ndim == 4 and height == 8: + assert final_l1_dist < 2e-4 + elif ndim == 4 and height == 16: + assert final_l1_dist < 3e-5 + + +@pytest.mark.parametrize("ndim", [2, 4]) +@pytest.mark.parametrize("alpha", [0.1, 1.0]) +def test_discreteebm(ndim: int, alpha: float): + n_trajectories = 16000 + args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) + final_l1_dist = train_discreteebm_main(args) + if ndim == 2 and alpha == 0.1: + assert final_l1_dist < 6e-3 + elif ndim == 2 and alpha == 1.0: + assert final_l1_dist < 3e-2 + elif ndim == 4 and alpha == 0.1: + assert final_l1_dist < 9e-3 + elif ndim == 4 and alpha == 1.0: + assert final_l1_dist < 7e-2 + + +@pytest.mark.parametrize("delta", [0.1, 0.25]) +@pytest.mark.parametrize("loss", ["TB", "DB"]) +def test_box(delta: float, loss: str): + n_trajectories = 128128 + validation_interval = 500 + args = BoxArgs( + delta=delta, + loss=loss, + n_trajectories=n_trajectories, + hidden_dim=128, + n_hidden=4, + batch_size=128, + validation_interval=validation_interval, + ) + final_jsd = train_box_main(args) + if loss == "TB" and delta == 0.1: + assert final_jsd < 7e-2 + elif loss == "DB" and delta == 0.1: + assert final_jsd < 0.2 + if loss == "TB" and delta == 0.25: + assert final_jsd < 1e-3 + elif loss == "DB" and delta == 0.25: + assert final_jsd < 4e-2 diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 2e0b21b2..6d32f254 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -82,156 +82,7 @@ def estimate_jsd(kde1, kde2): return jsd / 2.0 -if __name__ == "__main__": # noqa: C901 - parser = ArgumentParser() - - parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") - - parser.add_argument( - "--delta", - type=float, - default=0.25, - help="maximum distance between two successive states", - ) - - parser.add_argument( - "--seed", - type=int, - default=0, - help="Random seed, if 0 then a random seed is used", - ) - parser.add_argument( - "--batch_size", - type=int, - default=128, - help="Batch size, i.e. number of trajectories to sample per training iteration", - ) - - parser.add_argument( - "--loss", - type=str, - choices=["TB", "DB", "SubTB", "ZVar"], - default="TB", - help="Loss function to use", - ) - parser.add_argument( - "--subTB_weighting", - type=str, - default="geometric_within", - help="weighting scheme for SubTB", - ) - parser.add_argument( - "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" - ) - - parser.add_argument( - "--min_concentration", - type=float, - default=0.1, - help="minimal value for the Beta concentration parameters", - ) - - parser.add_argument( - "--max_concentration", - type=float, - default=5.1, - help="maximal value for the Beta concentration parameters", - ) - - parser.add_argument( - "--n_components", - type=int, - default=2, - help="Number of Beta distributions for P_F(s'|s)", - ) - parser.add_argument( - "--n_components_s0", - type=int, - default=4, - help="Number of Beta distributions for P_F(s'|s_0)", - ) - - parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") - parser.add_argument( - "--tied", - action="store_true", - help="Tie the parameters of PF, PB. F is never tied.", - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=128, - help="Hidden dimension of the estimators' neural network modules.", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=4, - help="Number of hidden layers (of size `hidden_dim`) in the estimators'" - + " neural network modules", - ) - - parser.add_argument( - "--lr", - type=float, - default=1e-3, - help="Learning rate for the estimators' modules", - ) - parser.add_argument( - "--lr_Z", - type=float, - default=1e-3, - help="Specific learning rate for logZ", - ) - parser.add_argument( - "--lr_F", - type=float, - default=1e-2, - help="Specific learning rate for the state flow function (only used for DB and SubTB losses)", - ) - parser.add_argument( - "--gamma_scheduler", - type=float, - default=0.5, - help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", - ) - parser.add_argument( - "--scheduler_milestone", - type=int, - default=2500, - help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", - ) - - parser.add_argument( - "--n_trajectories", - type=int, - default=int(3e6), - help="Total budget of trajectories to train on. " - + "Training iterations = n_trajectories // batch_size", - ) - - parser.add_argument( - "--validation_interval", - type=int, - default=500, - help="How often (in training steps) to validate the gflownet", - ) - parser.add_argument( - "--validation_samples", - type=int, - default=10000, - help="Number of validation samples to use to evaluate the probability mass function.", - ) - - parser.add_argument( - "--wandb_project", - type=str, - default="", - help="Name of the wandb project. If empty, don't use wandb", - ) - - args = parser.parse_args() - +def main(args): seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -420,3 +271,158 @@ def estimate_jsd(kde1, kde2): wandb.log({"JSD": jsd}, step=iteration) to_log.update({"JSD": jsd}) + + return jsd + + +if __name__ == "__main__": # noqa: C901 + parser = ArgumentParser() + + parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + + parser.add_argument( + "--delta", + type=float, + default=0.25, + help="maximum distance between two successive states", + ) + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed, if 0 then a random seed is used", + ) + parser.add_argument( + "--batch_size", + type=int, + default=128, + help="Batch size, i.e. number of trajectories to sample per training iteration", + ) + + parser.add_argument( + "--loss", + type=str, + choices=["TB", "DB", "SubTB", "ZVar"], + default="TB", + help="Loss function to use", + ) + parser.add_argument( + "--subTB_weighting", + type=str, + default="geometric_within", + help="weighting scheme for SubTB", + ) + parser.add_argument( + "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" + ) + + parser.add_argument( + "--min_concentration", + type=float, + default=0.1, + help="minimal value for the Beta concentration parameters", + ) + + parser.add_argument( + "--max_concentration", + type=float, + default=5.1, + help="maximal value for the Beta concentration parameters", + ) + + parser.add_argument( + "--n_components", + type=int, + default=2, + help="Number of Beta distributions for P_F(s'|s)", + ) + parser.add_argument( + "--n_components_s0", + type=int, + default=4, + help="Number of Beta distributions for P_F(s'|s_0)", + ) + + parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") + parser.add_argument( + "--tied", + action="store_true", + help="Tie the parameters of PF, PB. F is never tied.", + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=128, + help="Hidden dimension of the estimators' neural network modules.", + ) + parser.add_argument( + "--n_hidden", + type=int, + default=4, + help="Number of hidden layers (of size `hidden_dim`) in the estimators'" + + " neural network modules", + ) + + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for the estimators' modules", + ) + parser.add_argument( + "--lr_Z", + type=float, + default=1e-3, + help="Specific learning rate for logZ", + ) + parser.add_argument( + "--lr_F", + type=float, + default=1e-2, + help="Specific learning rate for the state flow function (only used for DB and SubTB losses)", + ) + parser.add_argument( + "--gamma_scheduler", + type=float, + default=0.5, + help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", + ) + parser.add_argument( + "--scheduler_milestone", + type=int, + default=2500, + help="Every scheduler_milestone steps, multiply the learning rate by gamma_scheduler", + ) + + parser.add_argument( + "--n_trajectories", + type=int, + default=int(3e6), + help="Total budget of trajectories to train on. " + + "Training iterations = n_trajectories // batch_size", + ) + + parser.add_argument( + "--validation_interval", + type=int, + default=500, + help="How often (in training steps) to validate the gflownet", + ) + parser.add_argument( + "--validation_samples", + type=int, + default=10000, + help="Number of validation samples to use to evaluate the probability mass function.", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. If empty, don't use wandb", + ) + + args = parser.parse_args() + + print(main(args)) diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 0dcd02a8..929b6e1f 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -23,6 +23,80 @@ from gfn.utils.common import validate from gfn.utils.modules import NeuralNet, Tabular + +def main(args): + seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() + torch.manual_seed(seed) + + device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" + + use_wandb = len(args.wandb_project) > 0 + if use_wandb: + wandb.init(project=args.wandb_project) + wandb.config.update(args) + + # 1. Create the environment + env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha) + + # 2. Create the gflownet. + # We need a LogEdgeFlowEstimator + if args.tabular: + module = Tabular(n_states=env.n_states, output_dim=env.n_actions) + else: + module = NeuralNet( + input_dim=env.preprocessor.output_dim, + output_dim=env.n_actions, + hidden_dim=args.hidden_dim, + n_hidden_layers=args.n_hidden, + ) + estimator = DiscretePolicyEstimator( + module=module, + n_actions=env.n_actions, + preprocessor=env.preprocessor, + ) + gflownet = FMGFlowNet(estimator) + + # 3. Create the optimizer + optimizer = torch.optim.Adam(module.parameters(), lr=args.lr) + + # 4. Train the gflownet + + visited_terminating_states = env.States.from_batch_shape((0,)) + + states_visited = 0 + n_iterations = args.n_trajectories // args.batch_size + validation_info = {"l1_dist": float("inf")} + for iteration in trange(n_iterations): + trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) + training_samples = gflownet.to_training_samples(trajectories) + + optimizer.zero_grad() + loss = gflownet.loss(env, training_samples) + loss.backward() + optimizer.step() + + visited_terminating_states.extend(trajectories.last_states) + + states_visited += len(trajectories) + + to_log = {"loss": loss.item(), "states_visited": states_visited} + if use_wandb: + wandb.log(to_log, step=iteration) + if iteration % args.validation_interval == 0: + validation_info = validate( + env, + gflownet, + args.validation_samples, + visited_terminating_states, + ) + if use_wandb: + wandb.log(validation_info, step=iteration) + to_log.update(validation_info) + tqdm.write(f"{iteration}: {to_log}") + + return validation_info["l1_dist"] + + if __name__ == "__main__": parser = ArgumentParser() @@ -104,70 +178,4 @@ args = parser.parse_args() - seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() - torch.manual_seed(seed) - - device_str = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu" - - use_wandb = len(args.wandb_project) > 0 - if use_wandb: - wandb.init(project=args.wandb_project) - wandb.config.update(args) - - # 1. Create the environment - env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha) - - # 2. Create the gflownet. - # We need a LogEdgeFlowEstimator - if args.tabular: - module = Tabular(n_states=env.n_states, output_dim=env.n_actions) - else: - module = NeuralNet( - input_dim=env.preprocessor.output_dim, - output_dim=env.n_actions, - hidden_dim=args.hidden_dim, - n_hidden_layers=args.n_hidden, - ) - estimator = DiscretePolicyEstimator( - module=module, - n_actions=env.n_actions, - preprocessor=env.preprocessor, - ) - gflownet = FMGFlowNet(estimator) - - # 3. Create the optimizer - optimizer = torch.optim.Adam(module.parameters(), lr=args.lr) - - # 4. Train the gflownet - - visited_terminating_states = env.States.from_batch_shape((0,)) - - states_visited = 0 - n_iterations = args.n_trajectories // args.batch_size - for iteration in trange(n_iterations): - trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) - training_samples = gflownet.to_training_samples(trajectories) - - optimizer.zero_grad() - loss = gflownet.loss(env, training_samples) - loss.backward() - optimizer.step() - - visited_terminating_states.extend(trajectories.last_states) - - states_visited += len(trajectories) - - to_log = {"loss": loss.item(), "states_visited": states_visited} - if use_wandb: - wandb.log(to_log, step=iteration) - if iteration % args.validation_interval == 0: - validation_info = validate( - env, - gflownet, - args.validation_samples, - visited_terminating_states, - ) - if use_wandb: - wandb.log(validation_info, step=iteration) - to_log.update(validation_info) - tqdm.write(f"{iteration}: {to_log}") + print(main(args)) diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 242296b5..79991492 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -31,123 +31,8 @@ from gfn.utils.common import validate from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -if __name__ == "__main__": # noqa: C901 - parser = ArgumentParser() - - parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") - - parser.add_argument( - "--ndim", type=int, default=2, help="Number of dimensions in the environment" - ) - parser.add_argument( - "--height", type=int, default=8, help="Height of the environment" - ) - parser.add_argument("--R0", type=float, default=0.1, help="Environment's R0") - parser.add_argument("--R1", type=float, default=0.5, help="Environment's R1") - parser.add_argument("--R2", type=float, default=2.0, help="Environment's R2") - - parser.add_argument( - "--seed", - type=int, - default=0, - help="Random seed, if 0 then a random seed is used", - ) - parser.add_argument( - "--batch_size", - type=int, - default=16, - help="Batch size, i.e. number of trajectories to sample per training iteration", - ) - parser.add_argument( - "--replay_buffer_size", - type=int, - default=0, - help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.", - ) - - parser.add_argument( - "--loss", - type=str, - choices=["FM", "TB", "DB", "SubTB", "ZVar", "ModifiedDB"], - default="TB", - help="Loss function to use", - ) - parser.add_argument( - "--subTB_weighting", - type=str, - default="geometric_within", - help="weighting scheme for SubTB", - ) - parser.add_argument( - "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" - ) - - parser.add_argument( - "--tabular", - action="store_true", - help="Use a lookup table for F, PF, PB instead of an estimator", - ) - parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") - parser.add_argument( - "--tied", action="store_true", help="Tie the parameters of PF, PB, and F" - ) - parser.add_argument( - "--hidden_dim", - type=int, - default=256, - help="Hidden dimension of the estimators' neural network modules.", - ) - parser.add_argument( - "--n_hidden", - type=int, - default=2, - help="Number of hidden layers (of size `hidden_dim`) in the estimators'" - + " neural network modules", - ) - - parser.add_argument( - "--lr", - type=float, - default=1e-3, - help="Learning rate for the estimators' modules", - ) - parser.add_argument( - "--lr_Z", - type=float, - default=0.1, - help="Specific learning rate for Z (only used for TB loss)", - ) - - parser.add_argument( - "--n_trajectories", - type=int, - default=int(1e6), - help="Total budget of trajectories to train on. " - + "Training iterations = n_trajectories // batch_size", - ) - - parser.add_argument( - "--validation_interval", - type=int, - default=100, - help="How often (in training steps) to validate the gflownet", - ) - parser.add_argument( - "--validation_samples", - type=int, - default=200000, - help="Number of validation samples to use to evaluate the probability mass function.", - ) - - parser.add_argument( - "--wandb_project", - type=str, - default="", - help="Name of the wandb project. If empty, don't use wandb", - ) - - args = parser.parse_args() +def main(args): seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -335,6 +220,7 @@ states_visited = 0 n_iterations = args.n_trajectories // args.batch_size + validation_info = {"l1_dist": float("inf")} for iteration in trange(n_iterations): trajectories = gflownet.sample_trajectories(env, n_samples=args.batch_size) training_samples = gflownet.to_training_samples(trajectories) @@ -368,3 +254,125 @@ wandb.log(validation_info, step=iteration) to_log.update(validation_info) tqdm.write(f"{iteration}: {to_log}") + + return validation_info["l1_dist"] + + +if __name__ == "__main__": # noqa: C901 + parser = ArgumentParser() + + parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") + + parser.add_argument( + "--ndim", type=int, default=2, help="Number of dimensions in the environment" + ) + parser.add_argument( + "--height", type=int, default=8, help="Height of the environment" + ) + parser.add_argument("--R0", type=float, default=0.1, help="Environment's R0") + parser.add_argument("--R1", type=float, default=0.5, help="Environment's R1") + parser.add_argument("--R2", type=float, default=2.0, help="Environment's R2") + + parser.add_argument( + "--seed", + type=int, + default=0, + help="Random seed, if 0 then a random seed is used", + ) + parser.add_argument( + "--batch_size", + type=int, + default=16, + help="Batch size, i.e. number of trajectories to sample per training iteration", + ) + parser.add_argument( + "--replay_buffer_size", + type=int, + default=0, + help="If zero, no replay buffer is used. Otherwise, the replay buffer is used.", + ) + + parser.add_argument( + "--loss", + type=str, + choices=["FM", "TB", "DB", "SubTB", "ZVar", "ModifiedDB"], + default="TB", + help="Loss function to use", + ) + parser.add_argument( + "--subTB_weighting", + type=str, + default="geometric_within", + help="weighting scheme for SubTB", + ) + parser.add_argument( + "--subTB_lambda", type=float, default=0.9, help="Lambda parameter for SubTB" + ) + + parser.add_argument( + "--tabular", + action="store_true", + help="Use a lookup table for F, PF, PB instead of an estimator", + ) + parser.add_argument("--uniform_pb", action="store_true", help="Use a uniform PB") + parser.add_argument( + "--tied", action="store_true", help="Tie the parameters of PF, PB, and F" + ) + parser.add_argument( + "--hidden_dim", + type=int, + default=256, + help="Hidden dimension of the estimators' neural network modules.", + ) + parser.add_argument( + "--n_hidden", + type=int, + default=2, + help="Number of hidden layers (of size `hidden_dim`) in the estimators'" + + " neural network modules", + ) + + parser.add_argument( + "--lr", + type=float, + default=1e-3, + help="Learning rate for the estimators' modules", + ) + parser.add_argument( + "--lr_Z", + type=float, + default=0.1, + help="Specific learning rate for Z (only used for TB loss)", + ) + + parser.add_argument( + "--n_trajectories", + type=int, + default=int(1e6), + help="Total budget of trajectories to train on. " + + "Training iterations = n_trajectories // batch_size", + ) + + parser.add_argument( + "--validation_interval", + type=int, + default=100, + help="How often (in training steps) to validate the gflownet", + ) + parser.add_argument( + "--validation_samples", + type=int, + default=200000, + help="Number of validation samples to use to evaluate the probability mass function.", + ) + + parser.add_argument( + "--wandb_project", + type=str, + default="", + help="Name of the wandb project. If empty, don't use wandb", + ) + + args = parser.parse_args() + + print(main(args)) From f5ac047ead7da6862d2b25082e4b03f75e68b77c Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 16:48:15 -0400 Subject: [PATCH 10/21] fix args for box tests --- tutorials/examples/test_scripts.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index cf16e319..f1313ae3 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -92,6 +92,7 @@ def test_discreteebm(ndim: int, alpha: float): def test_box(delta: float, loss: str): n_trajectories = 128128 validation_interval = 500 + validation_samples = 10000 args = BoxArgs( delta=delta, loss=loss, @@ -99,8 +100,11 @@ def test_box(delta: float, loss: str): hidden_dim=128, n_hidden=4, batch_size=128, + lr_Z=1e-3, validation_interval=validation_interval, + validation_samples=validation_samples, ) + print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: assert final_jsd < 7e-2 From 1e796b16707314d51e53ba45ef262d7f3dcf5eca Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 16:59:04 -0400 Subject: [PATCH 11/21] fix seed for script tests --- tutorials/examples/test_scripts.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index f1313ae3..f1c07baa 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -9,7 +9,7 @@ @dataclass class CommonArgs: no_cuda: bool = True - seed: int = 0 + seed: int = 1 # We fix the seed for reproducibility batch_size: int = 16 replay_buffer_size: int = 0 loss: str = "TB" @@ -62,13 +62,13 @@ def test_hypergrid(ndim: int, height: int): args = HypergridArgs(ndim=ndim, height=height, n_trajectories=n_trajectories) final_l1_dist = train_hypergrid_main(args) if ndim == 2 and height == 8: - assert final_l1_dist < 1e-3 + assert final_l1_dist < 7.3e-4 elif ndim == 2 and height == 16: - assert final_l1_dist < 6e-4 + assert final_l1_dist < 4.8e-4 elif ndim == 4 and height == 8: - assert final_l1_dist < 2e-4 + assert final_l1_dist < 1.6e-4 elif ndim == 4 and height == 16: - assert final_l1_dist < 3e-5 + assert final_l1_dist < 2.45e-5 @pytest.mark.parametrize("ndim", [2, 4]) @@ -78,13 +78,13 @@ def test_discreteebm(ndim: int, alpha: float): args = DiscreteEBMArgs(ndim=ndim, alpha=alpha, n_trajectories=n_trajectories) final_l1_dist = train_discreteebm_main(args) if ndim == 2 and alpha == 0.1: - assert final_l1_dist < 6e-3 + assert final_l1_dist < 0.0026 elif ndim == 2 and alpha == 1.0: - assert final_l1_dist < 3e-2 + assert final_l1_dist < 0.017 elif ndim == 4 and alpha == 0.1: - assert final_l1_dist < 9e-3 + assert final_l1_dist < 0.009 elif ndim == 4 and alpha == 1.0: - assert final_l1_dist < 7e-2 + assert final_l1_dist < 0.062 @pytest.mark.parametrize("delta", [0.1, 0.25]) @@ -107,10 +107,10 @@ def test_box(delta: float, loss: str): print(args) final_jsd = train_box_main(args) if loss == "TB" and delta == 0.1: - assert final_jsd < 7e-2 + assert final_jsd < 0.046 elif loss == "DB" and delta == 0.1: - assert final_jsd < 0.2 + assert final_jsd < 0.18 if loss == "TB" and delta == 0.25: - assert final_jsd < 1e-3 + assert final_jsd < 0.015 elif loss == "DB" and delta == 0.25: - assert final_jsd < 4e-2 + assert final_jsd < 0.027 From e1ea13e0ea8781b4de1e14941602b363c07afad8 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 18:47:51 -0400 Subject: [PATCH 12/21] pre-commit stuff --- .pre-commit-config.yaml | 2 ++ src/gfn/gym/helpers/test_box_utils.py | 2 +- tutorials/examples/test_scripts.py | 1 + tutorials/examples/train_box.py | 4 ++-- tutorials/examples/train_discreteebm.py | 4 ++-- tutorials/examples/train_hypergrid.py | 4 ++-- 6 files changed, 10 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75e15a9f..1aa827ce 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,8 @@ repos: name: pytest-check entry: pytest language: python + # The next line excludes the file test_scripts.py + exclude: ^test_scripts.py$ pass_filenames: false types: [python] always_run: true diff --git a/src/gfn/gym/helpers/test_box_utils.py b/src/gfn/gym/helpers/test_box_utils.py index aefd12e2..fb004140 100644 --- a/src/gfn/gym/helpers/test_box_utils.py +++ b/src/gfn/gym/helpers/test_box_utils.py @@ -1,5 +1,5 @@ -import torch import pytest +import torch from gfn.gym import Box from gfn.gym.helpers.box_utils import ( diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index f1c07baa..71afa8f3 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -1,4 +1,5 @@ from dataclasses import dataclass + import pytest from .train_box import main as train_box_main diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index 6d32f254..d8e879bd 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -82,7 +82,7 @@ def estimate_jsd(kde1, kde2): return jsd / 2.0 -def main(args): +def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -275,7 +275,7 @@ def main(args): return jsd -if __name__ == "__main__": # noqa: C901 +if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") diff --git a/tutorials/examples/train_discreteebm.py b/tutorials/examples/train_discreteebm.py index 929b6e1f..a7aab784 100644 --- a/tutorials/examples/train_discreteebm.py +++ b/tutorials/examples/train_discreteebm.py @@ -24,7 +24,7 @@ from gfn.utils.modules import NeuralNet, Tabular -def main(args): +def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -36,7 +36,7 @@ def main(args): wandb.config.update(args) # 1. Create the environment - env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha) + env = DiscreteEBM(ndim=args.ndim, alpha=args.alpha, device_str=device_str) # 2. Create the gflownet. # We need a LogEdgeFlowEstimator diff --git a/tutorials/examples/train_hypergrid.py b/tutorials/examples/train_hypergrid.py index 79991492..e9fd465c 100644 --- a/tutorials/examples/train_hypergrid.py +++ b/tutorials/examples/train_hypergrid.py @@ -32,7 +32,7 @@ from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular -def main(args): +def main(args): # noqa: C901 seed = args.seed if args.seed != 0 else torch.randint(int(10e10), (1,))[0].item() torch.manual_seed(seed) @@ -258,7 +258,7 @@ def main(args): return validation_info["l1_dist"] -if __name__ == "__main__": # noqa: C901 +if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--no_cuda", action="store_true", help="Prevent CUDA usage") From 9913493693898d82ca7dd721d45f979162c38106 Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 19:13:34 -0400 Subject: [PATCH 13/21] only quick tests during precommit --- .pre-commit-config.yaml | 4 +--- README.md | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1aa827ce..94f5d6a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,8 +45,6 @@ repos: name: pytest-check entry: pytest language: python - # The next line excludes the file test_scripts.py - exclude: ^test_scripts.py$ - pass_filenames: false + files: testing/ types: [python] always_run: true diff --git a/README.md b/README.md index 44f3f0b9..bf9d1a35 100644 --- a/README.md +++ b/README.md @@ -116,7 +116,7 @@ pre-commit install pre-commit run --all-files ``` -Run `pre-commit` after staging, and before committing. Make sure all the tests pass (By running `pytest`). +Run `pre-commit` after staging, and before committing. Make sure all the tests pass (By running `pytest`). Note that the `pytest` hook of `pre-commit` only runs the tests in the `testing/` folder. To run all the tests, which take longer, run `pytest` manually. The codebase uses `black` formatter. To make the docs locally: From bfd1c8e4e90d3bf618f08d73fa91915335068dea Mon Sep 17 00:00:00 2001 From: Salem Date: Sun, 6 Aug 2023 19:24:07 -0400 Subject: [PATCH 14/21] fix n_components --- tutorials/examples/test_scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index 71afa8f3..b2d0656a 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -50,7 +50,7 @@ class BoxArgs(CommonArgs): min_concentration: float = 0.1 max_concentration: float = 5.1 n_components: int = 2 - n_components_s0: int = 2 + n_components_s0: int = 4 gamma_scheduler: float = 0.5 scheduler_milestone: int = 2500 lr_F: float = 1e-2 From 53edcd79fb23fd8e5efff77d0a40d65f2ccbe4b2 Mon Sep 17 00:00:00 2001 From: Salem Date: Mon, 7 Aug 2023 18:30:26 -0400 Subject: [PATCH 15/21] we do not clip logZ grad --- tutorials/examples/train_box.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tutorials/examples/train_box.py b/tutorials/examples/train_box.py index d8e879bd..996f4c1f 100644 --- a/tutorials/examples/train_box.py +++ b/tutorials/examples/train_box.py @@ -238,7 +238,7 @@ def main(args): # noqa: C901 loss.backward() for p in gflownet.parameters(): - if p.grad is not None: + if p.ndim > 0 and p.grad is not None: # We do not clip logZ grad p.grad.data.clamp_(-10, 10).nan_to_num_(0.0) optimizer.step() scheduler.step() From 41f62135a16aeaf40b4d6336ed03ca22274e886a Mon Sep 17 00:00:00 2001 From: Salem Date: Mon, 7 Aug 2023 18:35:05 -0400 Subject: [PATCH 16/21] add explanation to script --- tutorials/examples/test_scripts.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tutorials/examples/test_scripts.py b/tutorials/examples/test_scripts.py index b2d0656a..0f63316a 100644 --- a/tutorials/examples/test_scripts.py +++ b/tutorials/examples/test_scripts.py @@ -1,3 +1,8 @@ +# This file includes tests for the three examples in the tutorials folder. +# The tests ensure that after a certain number of iterations, the final L1 distance +# or JSD between the learned distribution and the target distribution is below a +# certain threshold. + from dataclasses import dataclass import pytest From 834bf5f82815966db6112c66b3ac3785e341cfd4 Mon Sep 17 00:00:00 2001 From: Salem Date: Tue, 8 Aug 2023 15:01:34 -0400 Subject: [PATCH 17/21] remove named_parameters --- src/gfn/utils/modules.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index 92524a42..a1f43ef1 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -139,6 +139,3 @@ def forward( preprocessed_states.device ) return out - - def named_parameters(self) -> Iterator[Tuple[str, Parameter]]: - return iter([]) From f1a519ce69f80d00a5ebf4127b77740e834c2679 Mon Sep 17 00:00:00 2001 From: Salem Date: Tue, 8 Aug 2023 15:04:33 -0400 Subject: [PATCH 18/21] update the README --- README.md | 57 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/README.md b/README.md index bf9d1a35..dac4f537 100644 --- a/README.md +++ b/README.md @@ -48,9 +48,13 @@ pip install . ## About this repo -This repo serves the purpose of fast prototyping [GFlowNet](https://arxiv.org/abs/2111.09266) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. +This repo serves the purpose of fast prototyping [GFlowNet](https://arxiv.org/abs/2111.09266) (GFN) related algorithms. It decouples the environment definition, the sampling process, and the parametrization of the function approximators used to calculate the GFN loss. It aims to accompany researchers and engineers in learning about GFlowNets, and in developing new algorithms. -Example scripts and notebooks are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/). +Currently, the library is shipped with three environments: two discrete environments (Discrete Energy Based Model and Hyper Grid) and a continuous box environment. The library is designed to allow users to define their own environments. See [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md) for more details. + +### Scripts and notebooks + +Example scripts and notebooks for the three environments are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/examples). For the hyper grid and the box environments, the provided scripts are supposed to reproduce published results. ### Standalone example @@ -61,32 +65,43 @@ This example, which shows how to use the library for a simple discrete environme import torch from tqdm import tqdm -from gfn.gflownet import TBGFlowNet -from gfn.gym import HyperGrid +from gfn.gflownet import TBGFlowNet # We use a GFlowNet with the Trajectory Balance (TB) loss +from gfn.gym import HyperGrid # We use the hyper grid environment from gfn.modules import DiscretePolicyEstimator from gfn.samplers import Sampler -from gfn.utils import NeuralNet +from gfn.utils import NeuralNet # NeuralNet is a simple multi-layer perceptron (MLP) if __name__ == "__main__": + # 1 - We define the environment + env = HyperGrid(ndim=4, height=8, R0=0.01) # Grid of size 8x8x8x8 + # 2 - We define the needed modules (neural networks) + + # The environment has a preprocessor attribute, which is used to preprocess the state before feeding it to the policy estimator module_PF = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - ) + ) # Neural network for the forward policy, with as many outputs as there are actions module_PB = NeuralNet( input_dim=env.preprocessor.output_dim, output_dim=env.n_actions - 1, - torso=module_PF.torso + torso=module_PF.torso # We share all the parameters of P_F and P_B, except for the last layer ) + # 3 - We define the estimators + pf_estimator = DiscretePolicyEstimator(module_PF, env.n_actions, is_backward=False, preprocessor=env.preprocessor) pb_estimator = DiscretePolicyEstimator(module_PB, env.n_actions, is_backward=True, preprocessor=env.preprocessor) - gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) + # 4 - We define the GFlowNet - sampler = Sampler(estimator=pf_estimator) + gfn = TBGFlowNet(init_logZ=0., pf=pf_estimator, pb=pb_estimator) # We initialize logZ to 0 + + # 5 - We define the sampler and the optimizer + + sampler = Sampler(estimator=pf_estimator) # We use an on-policy sampler, based on the forward policy # Policy parameters have their own LR. non_logz_params = [v for k, v in dict(gfn.named_parameters()).items() if k != "logZ"] @@ -94,7 +109,9 @@ if __name__ == "__main__": # Log Z gets dedicated learning rate (typically higher). logz_params = [dict(gfn.named_parameters())["logZ"]] - optimizer.add_param_group({"params": logz_params, "lr": 1e-2}) + optimizer.add_param_group({"params": logz_params, "lr": 1e-1}) + + # 6 - We train the GFlowNet for 1000 iterations, with 16 trajectories per iteration for i in (pbar := tqdm(range(1000))): trajectories = sampler.sample_trajectories(env=env, n_trajectories=16) @@ -145,7 +162,7 @@ The `batch_shape` attribute is required to keep track of the batch dimension. A Because multiple trajectories can have different lengths, batching requires appending a dummy tensor to trajectories that are shorter than the longest trajectory. The dummy state is the $s_f$ attribute of the environment (e.g. `[-1, ..., -1]`, or `[-inf, ..., -inf]`, etc...). Which is never processed, and is used to pad the batch of states only. -For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the $(n_{actions})$-th action always corresponds to the exit or terminate action, i.e. that results in a transition of the type $s \rightarrow s_f$, but not all actions are possible at all states. Each `States` object is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, that uses the environment's logic to define the two tensors. +For discrete environments, the action set is represented with the set $\{0, \dots, n_{actions} - 1\}$, where the $(n_{actions})$-th action always corresponds to the exit or terminate action, i.e. that results in a transition of the type $s \rightarrow s_f$, but not all actions are possible at all states. For discrete environments, each `States` object is endowed with two extra attributes: `forward_masks` and `backward_masks`, representing which actions are allowed at each state and which actions could have led to each state, respectively. Such states are instances of the `DiscreteStates` abstract subclass of `States`. The `forward_masks` tensor is of shape `(*batch_shape, n_{actions})`, and `backward_masks` is of shape `(*batch_shape, n_{actions} - 1)`. Each subclass of `DiscreteStates` needs to implement the `update_masks` function, that uses the environment's logic to define the two tensors. ### Actions Actions should be though of as internal actions of an agent building a compositional object. They correspond to transitions $s \rightarrow s'$. An abstract `Actions` class is provided. It is automatically subclassed for discrete environments, but needs to be manually subclassed otherwise. @@ -163,32 +180,33 @@ Containers are collections of `States`, along with other information, such as re - [Transitions](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/transitions.py), representing a batch of transitions $s \rightarrow s'$. - [Trajectories](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/trajectories.py), representing a batch of complete trajectories $\tau = s_0 \rightarrow s_1 \rightarrow \dots \rightarrow s_n \rightarrow s_f$. -These containers can either be instantiated using a `States` object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the[ReplayBuffer](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/replay_buffer.py) class. +These containers can either be instantiated using a `States` object, or can be initialized as empty containers that can be populated on the fly, allowing the usage of the [ReplayBuffer](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/replay_buffer.py) class. They inherit from the base `Container` [class](https://github.com/saleml/torchgfn/tree/master/src/gfn/containers/base.py), indicating some helpful methods. -In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and batch of transitions can be defined using `Trajectories.to_transitions()` and `Trajectories.to_states()`. These exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching. +In most cases, one needs to sample complete trajectories. From a batch of trajectories, a batch of states and batch of transitions can be defined using `Trajectories.to_transitions()` and `Trajectories.to_states()`, in order to train GFlowNets with losses that are edge-decomposable or state-decomposable. These exclude meaningless transitions and dummy states that were added to the batch of trajectories to allow for efficient batching. ### Modules Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. -- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. Additionally, they include exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. Naturally, before defining the Categorical distributions, forbidden actions (that are encoded in the `DiscreteStates`' masks attributes), are given 0 probability by setting the corresponding logit to $-\infty$. + +- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then return a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. - `ScalarModule` is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$. For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States`) object, should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user defined environment need this much level of details. -In all `GFNModule`s, note that the input of the `forward` function is a `States` object. Meaning that they first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can used to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The default preprocessor of an environment is the identity preprocessor. The `forward` pass thus first calls the `preprocessor` attribute of the environment on `States`, before performing any transformation. +In all `GFNModule`s, note that the input of the `forward` function is a `States` object. Meaning that they first need to be transformed to tensors. However, `states.tensor` does not necessarily include the structure that a neural network can used to generalize. It is common in these scenarios to have a function that transforms these raw tensor states to ones where the structure is clearer, via a `Preprocessor` object, that is part of the environment. More on this [here](https://github.com/saleml/torchgfn/tree/master/tutorials/ENV.md). The default preprocessor of an environment is the identity preprocessor. The `forward` pass thus first calls the `preprocessor` attribute of the environment on `States`, before performing any transformation. The `preprocessor` is thus an attribute of the module. If it is not explicitly defined, it is set to the identity preprocessor. -For discrete environments, tabular modules are provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. +For discrete environments, a `Tabular` module is provided, where a lookup table is used instead of a neural network. Additionally, a `UniformPB` module is provided, implementing a uniform backward policy. These modules are provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/utils/modules.py). ### Samplers -A [Sampler](https://github.com/saleml/torchgfn/tree/master/src/gfn/samplers.py) object defines how actions are sampled (`sample_actions()`) at each state, and trajectories (`sample_trajectories()`), which can sample a batch of trajectories starting from a given set of initial states or starting from $s_0$. It requires a `GFNModule` that implements the `to_probability_distribution` function. +A [Sampler](https://github.com/saleml/torchgfn/tree/master/src/gfn/samplers.py) object defines how actions are sampled (`sample_actions()`) at each state, and trajectories (`sample_trajectories()`), which can sample a batch of trajectories starting from a given set of initial states or starting from $s_0$. It requires a `GFNModule` that implements the `to_probability_distribution` function. For off-policy sampling, the parameters of `to_probability_distribution` can be directly passed when initializing the `Sampler`. ### Losses -GFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a `GFlowNet`. A `GFlowNet` is a `GFNModule` that includes one or multiple `GFNModules`, at least one of which implements a `to_probability_distribution` function. They also need to implement a `loss` function, that takes as input either states, transitions, or trajectories, depending on the loss. +GFlowNets can be trained with different losses, each of which requires a different parametrization, which we call in this library a `GFlowNet`. A `GFlowNet` is a `GFNModule` that includes one or multiple `GFNModule`s, at least one of which implements a `to_probability_distribution` function. They also need to implement a `loss` function, that takes as input either states, transitions, or trajectories, depending on the loss. Currently, the implemented losses are: @@ -197,6 +215,3 @@ Currently, the implemented losses are: - Trajectory Balance - Sub-Trajectory Balance. By default, each sub-trajectory is weighted geometrically (within the trajectory) depending on its length. This corresponds to the strategy defined [here](https://www.semanticscholar.org/reader/f2c32fe3f7f3e2e9d36d833e32ec55fc93f900f5). Other strategies exist and are implemented [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/losses/sub_trajectory_balance.py). - Log Partition Variance loss. Introduced [here](https://arxiv.org/abs/2302.05446) - -# Scripts -Example scripts are provided [here](https://github.com/saleml/torchgfn/tree/master/tutorials/examples/). They can be used to reproduce published results in the HyperGrid environment, and the Box environment. \ No newline at end of file From 1cf26292b7a9e0427be636f09f24b0f478a8bc28 Mon Sep 17 00:00:00 2001 From: Salem Date: Tue, 8 Aug 2023 15:05:11 -0400 Subject: [PATCH 19/21] autoflake doing autoflake stuff --- src/gfn/utils/modules.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gfn/utils/modules.py b/src/gfn/utils/modules.py index a1f43ef1..7379d276 100644 --- a/src/gfn/utils/modules.py +++ b/src/gfn/utils/modules.py @@ -1,6 +1,6 @@ """This file contains some examples of modules that can be used with GFN.""" -from typing import Iterator, Literal, Optional, Tuple +from typing import Literal, Optional import torch import torch.nn as nn From d9049816ea83757d58a890659a0e2f39d5d37b4f Mon Sep 17 00:00:00 2001 From: Salem Date: Mon, 14 Aug 2023 18:37:59 -0400 Subject: [PATCH 20/21] Be more explicit in README about DiscretePolicyEstimator --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index dac4f537..81ee0b60 100644 --- a/README.md +++ b/README.md @@ -190,7 +190,7 @@ In most cases, one needs to sample complete trajectories. From a batch of trajec Training GFlowNets requires one or multiple estimators, called `GFNModule`s, which is an abstract subclass of `torch.nn.Module`. In addition to the usual `forward` function, `GFNModule`s need to implement a `required_output_dim` attribute, to ensure that the outputs have the required dimension for the task at hand; and some (but not all) need to implement a `to_probability_distribution` function. -- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then return a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. +- `DiscretePolicyEstimator` is a `GFNModule` that defines the policies $P_F(. \mid s)$ and $P_B(. \mid s)$ for discrete environments. When `is_backward=False`, the required output dimension is `n = env.n_actions`, and when `is_backward=True`, it is `n = env.n_actions - 1`. These `n` numbers represent the logits of a Categorical distribution. The corresponding `to_probability_distribution` function transforms the logits by masking illegal actions (according to the forward or backward masks), then return a Categorical distribution. The masking is done by setting the corresponding logit to $-\infty$. The function also includes exploration parameters, in order to define a tempered version of $P_F$, or a mixture of $P_F$ with a uniform distribution. `DiscretePolicyEstimator`` with `is_backward=False`` can be used to represent log-edge-flow estimators $\log F(s \rightarrow s')$. - `ScalarModule` is a simple module with required output dimension 1. It is useful to define log-state flows $\log F(s)$. For non-discrete environments, the user needs to specify their own policies $P_F$ and $P_B$. The module, taking as input a batch of states (as a `States`) object, should return the batched parameters of a `torch.Distribution`. The distribution depends on the environment. The `to_probability_distribution` function handles the conversion of the parameter outputs to an actual batched `Distribution` object, that implements at least the `sample` and `log_prob` functions. An example is provided [here](https://github.com/saleml/torchgfn/tree/master/src/gfn/gym/helpers/box_utils.py), for a square environment in which the forward policy has support either on a quarter disk, or on an arc-circle, such that the angle, and the radius (for the quarter disk part) are scaled samples from a mixture of Beta distributions. The provided example shows an intricate scenario, and it is not expected that user defined environment need this much level of details. From cba92901b19e5a0b2e7f4c9e0cb85ea2b10214a4 Mon Sep 17 00:00:00 2001 From: Salem Date: Mon, 14 Aug 2023 18:38:05 -0400 Subject: [PATCH 21/21] up the version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1d584a8e..269e7d1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "torchgfn" packages = [{include = "gfn", from = "src"}] -version = "1.0.0" +version = "1.0.1" description = "A torch implementation of GFlowNets" authors = ["Salem Lahou ", "Joseph Viviano ", "Victor Schmidt "] license = "MIT"