diff --git a/src/gfn/actions.py b/src/gfn/actions.py index 1d579ae7..2006b018 100644 --- a/src/gfn/actions.py +++ b/src/gfn/actions.py @@ -32,12 +32,12 @@ def __init__(self, tensor: torch.Tensor): Args: tensor: tensors representing a batch of actions with shape (*batch_shape, *action_shape). """ - assert tensor.shape[-len(self.action_shape):] == self.action_shape, ( - f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}." - ) - + assert ( + tensor.shape[-len(self.action_shape) :] == self.action_shape + ), f"Batched actions tensor has shape {tensor.shape}, but the expected action shape is {self.action_shape}." + self.tensor = tensor - self.batch_shape = tuple(self.tensor.shape)[:-len(self.action_shape)] + self.batch_shape = tuple(self.tensor.shape)[: -len(self.action_shape)] @classmethod def make_dummy_actions(cls, batch_shape: tuple[int]) -> Actions: @@ -137,13 +137,13 @@ def compare(self, other: torch.Tensor) -> torch.Tensor: Args: other: tensor of actions to compare, with shape (*batch_shape, *action_shape). - + Returns: boolean tensor of shape batch_shape indicating whether the actions are equal. """ - assert other.shape == self.batch_shape + self.action_shape, ( - f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}." - ) + assert ( + other.shape == self.batch_shape + self.action_shape + ), f"Expected shape {self.batch_shape + self.action_shape}, got {other.shape}." out = self.tensor == other n_batch_dims = len(self.batch_shape) diff --git a/src/gfn/containers/trajectories.py b/src/gfn/containers/trajectories.py index 5de2d494..f9e8d87f 100644 --- a/src/gfn/containers/trajectories.py +++ b/src/gfn/containers/trajectories.py @@ -1,7 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Sequence, Union, Tuple - +from typing import TYPE_CHECKING, Sequence, Tuple, Union if TYPE_CHECKING: from gfn.actions import Actions @@ -92,20 +91,29 @@ def __init__( if when_is_done is not None else torch.full(size=(0,), fill_value=-1, dtype=torch.long) ) - assert self.when_is_done.shape == (self.n_trajectories,) and self.when_is_done.dtype == torch.long + assert ( + self.when_is_done.shape == (self.n_trajectories,) + and self.when_is_done.dtype == torch.long + ) self._log_rewards = ( log_rewards if log_rewards is not None else torch.full(size=(0,), fill_value=0, dtype=torch.float) ) - assert self._log_rewards.shape == (self.n_trajectories,) and self._log_rewards.dtype == torch.float + assert ( + self._log_rewards.shape == (self.n_trajectories,) + and self._log_rewards.dtype == torch.float + ) if log_probs is not None: - assert log_probs.shape == (self.max_length, self.n_trajectories) and log_probs.dtype == torch.float + assert ( + log_probs.shape == (self.max_length, self.n_trajectories) + and log_probs.dtype == torch.float + ) else: log_probs = torch.full(size=(0, 0), fill_value=0, dtype=torch.float) - self.log_probs = log_probs + self.log_probs = log_probs self.estimator_outputs = estimator_outputs if self.estimator_outputs is not None: @@ -207,15 +215,13 @@ def __getitem__(self, index: int | Sequence[int]) -> Trajectories: ) @staticmethod - def extend_log_probs( - log_probs: torch.Tensor, new_max_length: int - ) -> torch.Tensor: + def extend_log_probs(log_probs: torch.Tensor, new_max_length: int) -> torch.Tensor: """Extend the log_probs matrix by adding 0 until the required length is reached. - + Args: log_probs: The log_probs tensor of shape (max_length, n_trajectories) to extend. new_max_length: The new length of the log_probs tensor. - + Returns: The extended log_probs tensor of shape (new_max_length, n_trajectories). """ diff --git a/src/gfn/containers/transitions.py b/src/gfn/containers/transitions.py index 8ae945c9..3c03a53d 100644 --- a/src/gfn/containers/transitions.py +++ b/src/gfn/containers/transitions.py @@ -84,7 +84,10 @@ def __init__( if is_done is not None else torch.full(size=(0,), fill_value=False, dtype=torch.bool) ) - assert self.is_done.shape == (self.n_transitions,) and self.is_done.dtype == torch.bool + assert ( + self.is_done.shape == (self.n_transitions,) + and self.is_done.dtype == torch.bool + ) self.next_states = ( next_states @@ -96,9 +99,15 @@ def __init__( and self.states.batch_shape == self.next_states.batch_shape ) self._log_rewards = log_rewards if log_rewards is not None else torch.zeros(0) - assert self._log_rewards.shape == (self.n_transitions,) and self._log_rewards.dtype == torch.float + assert ( + self._log_rewards.shape == (self.n_transitions,) + and self._log_rewards.dtype == torch.float + ) self.log_probs = log_probs if log_probs is not None else torch.zeros(0) - assert self.log_probs.shape == (self.n_transitions,) and self.log_probs.dtype == torch.float + assert ( + self.log_probs.shape == (self.n_transitions,) + and self.log_probs.dtype == torch.float + ) @property def n_transitions(self) -> int: @@ -186,8 +195,11 @@ def all_log_rewards(self) -> torch.Tensor: log_rewards[~is_sink_state, 1] = torch.log( self.env.reward(self.next_states[~is_sink_state]) ) - - assert log_rewards.shape == (self.n_transitions, 2) and log_rewards.dtype == torch.float + + assert ( + log_rewards.shape == (self.n_transitions, 2) + and log_rewards.dtype == torch.float + ) return log_rewards def __getitem__(self, index: int | Sequence[int]) -> Transitions: diff --git a/src/gfn/gflownet/base.py b/src/gfn/gflownet/base.py index 8e3fd4b5..321306e0 100644 --- a/src/gfn/gflownet/base.py +++ b/src/gfn/gflownet/base.py @@ -1,6 +1,6 @@ import math from abc import ABC, abstractmethod -from typing import Generic, Tuple, TypeVar, Union, Any +from typing import Any, Generic, Tuple, TypeVar, Union import torch import torch.nn as nn @@ -211,7 +211,6 @@ def get_pfs_and_pbs( # Using all non-initial states, calculate the backward policy, and the logprobs # of those actions. if trajectories.conditioning is not None: - # We need to index the conditioning vector to broadcast over the states. cond_dim = (-1,) * len(trajectories.conditioning.shape) traj_len = trajectories.states.tensor.shape[0] @@ -242,8 +241,14 @@ def get_pfs_and_pbs( log_pb_trajectories_slice[~valid_actions.is_exit] = valid_log_pb_actions log_pb_trajectories[~trajectories.actions.is_dummy] = log_pb_trajectories_slice - assert log_pf_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories) - assert log_pb_trajectories.shape == (trajectories.max_length, trajectories.n_trajectories) + assert log_pf_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) + assert log_pb_trajectories.shape == ( + trajectories.max_length, + trajectories.n_trajectories, + ) return log_pf_trajectories, log_pb_trajectories def get_trajectories_scores( @@ -252,15 +257,15 @@ def get_trajectories_scores( recalculate_all_logprobs: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Given a batch of trajectories, calculate forward & backward policy scores. - + Args: trajectories: Trajectories to evaluate. recalculate_all_logprobs: Whether to re-evaluate all logprobs. - + Returns: A tuple of float tensors of shape (n_trajectories,) containing the total log_pf, total log_pb, and the total log-likelihood of the trajectories. - + """ log_pf_trajectories, log_pb_trajectories = self.get_pfs_and_pbs( trajectories, recalculate_all_logprobs=recalculate_all_logprobs @@ -279,7 +284,7 @@ def get_trajectories_scores( torch.isinf(total_log_pb_trajectories) ): raise ValueError("Infinite logprobs found") - + assert total_log_pf_trajectories.shape == (trajectories.n_trajectories,) assert total_log_pb_trajectories.shape == (trajectories.n_trajectories,) return ( diff --git a/src/gfn/gflownet/detailed_balance.py b/src/gfn/gflownet/detailed_balance.py index f8ca618e..d5002220 100644 --- a/src/gfn/gflownet/detailed_balance.py +++ b/src/gfn/gflownet/detailed_balance.py @@ -6,7 +6,7 @@ from gfn.containers import Trajectories, Transitions from gfn.env import Env from gfn.gflownet.base import PFBasedGFlowNet -from gfn.modules import GFNModule, ScalarEstimator, ConditionalScalarEstimator +from gfn.modules import ConditionalScalarEstimator, GFNModule, ScalarEstimator from gfn.utils.common import has_log_probs from gfn.utils.handlers import ( has_conditioning_exception_handler, @@ -91,7 +91,7 @@ def get_scores( - If transitions have log_probs attribute, use them - this is usually for on-policy learning - Else, re-evaluate the log_probs using the current self.pf - this is usually for off-policy learning with replay buffer - + Returns: A tuple of three tensors of shapes (n_transitions,), representing the log probabilities of the actions, the log probabilities of the backward actions, and th scores. @@ -194,7 +194,7 @@ def get_scores( assert valid_log_pf_actions.shape == (transitions.n_transitions,) assert log_pb_actions.shape == (transitions.n_transitions,) - assert scores.shape == (transitions.n_transitions,) + assert scores.shape == (transitions.n_transitions,) return valid_log_pf_actions, log_pb_actions, scores def loss(self, env: Env, transitions: Transitions) -> torch.Tensor: diff --git a/src/gfn/gflownet/flow_matching.py b/src/gfn/gflownet/flow_matching.py index c093ee97..38072080 100644 --- a/src/gfn/gflownet/flow_matching.py +++ b/src/gfn/gflownet/flow_matching.py @@ -1,16 +1,16 @@ -from typing import Tuple, Any, Union +from typing import Any, Tuple, Union import torch from gfn.containers import Trajectories from gfn.env import Env from gfn.gflownet.base import GFlowNet -from gfn.modules import DiscretePolicyEstimator, ConditionalDiscretePolicyEstimator +from gfn.modules import ConditionalDiscretePolicyEstimator, DiscretePolicyEstimator from gfn.samplers import Sampler from gfn.states import DiscreteStates, States from gfn.utils.handlers import ( - no_conditioning_exception_handler, has_conditioning_exception_handler, + no_conditioning_exception_handler, ) @@ -109,7 +109,6 @@ def flow_matching_loss( ) if conditioning is not None: - # Mask out only valid conditioning elements. valid_backward_conditioning = conditioning[valid_backward_mask] valid_forward_conditioning = conditioning[valid_forward_mask] @@ -204,7 +203,9 @@ def loss( ) return fm_loss + self.alpha * rm_loss - def to_training_samples(self, trajectories: Trajectories) -> Union[ + def to_training_samples( + self, trajectories: Trajectories + ) -> Union[ Tuple[DiscreteStates, DiscreteStates, torch.Tensor, torch.Tensor], Tuple[DiscreteStates, DiscreteStates, None, None], Tuple[States, States, torch.Tensor, torch.Tensor], diff --git a/src/gfn/gflownet/sub_trajectory_balance.py b/src/gfn/gflownet/sub_trajectory_balance.py index b8c41688..bf9d4d3e 100644 --- a/src/gfn/gflownet/sub_trajectory_balance.py +++ b/src/gfn/gflownet/sub_trajectory_balance.py @@ -6,14 +6,15 @@ from gfn.containers import Trajectories from gfn.env import Env from gfn.gflownet.base import TrajectoryBasedGFlowNet -from gfn.modules import GFNModule, ScalarEstimator, ConditionalScalarEstimator +from gfn.modules import ConditionalScalarEstimator, GFNModule, ScalarEstimator from gfn.utils.handlers import ( has_conditioning_exception_handler, no_conditioning_exception_handler, ) - -ContributionsTensor = torch.Tensor # shape: [max_len * (1 + max_len) / 2, n_trajectories] +ContributionsTensor = ( + torch.Tensor +) # shape: [max_len * (1 + max_len) / 2, n_trajectories] CumulativeLogProbsTensor = torch.Tensor # shape: [max_length + 1, n_trajectories] LogStateFlowsTensor = torch.Tensor # shape: [max_length, n_trajectories] LogTrajectoriesTensor = torch.Tensor # shape: [max_length, n_trajectories] @@ -115,7 +116,7 @@ def cumulative_logprobs( trajectories: a batch of trajectories. log_p_trajectories: log probabilities of each transition in each trajectory. - Returns: Tensor of shape (max_length + 1, n_trajectories), containing the + Returns: Tensor of shape (max_length + 1, n_trajectories), containing the cumulative sum of log probabilities of each trajectory. """ return torch.cat( @@ -136,12 +137,12 @@ def calculate_preds( ) -> PredictionsTensor: """ Calculate the predictions tensor for the current sub-trajectory length. - + Args: log_pf_trajectories_cum: Tensor of shape (max_length + 1, n_trajectories) containing the cumulative log probabilities of the forward actions. log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. i: The sub-trajectory length. - + Returns: The predictions tensor of shape (max_length + 1 - i, n_trajectories). """ current_log_state_flows = ( @@ -179,7 +180,7 @@ def calculate_targets( sink_states_mask: A mask tensor of shape (max_length, n_trajectories) representing sink states. full_mask: A mask tensor of shape (max_length, n_trajectories) representing full states. i: The sub-trajectory length. - + Returns: The targets tensor of shape (max_length + 1 - i, n_trajectories). """ targets = torch.full_like(preds, fill_value=-float("inf")) @@ -262,7 +263,7 @@ def calculate_masks( Args: log_state_flows: Tensor of shape (max_length, n_trajectories) containing the log state flows. trajectories: The trajectories data. - + Returns: a tuple of three mask tensors of shape (max_length, n_trajectories). """ sink_states_mask = log_state_flows == -float("inf") @@ -353,7 +354,7 @@ def get_equal_within_contributions( Args: trajectories: The trajectories data. all_scores: The scores tensor. - + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores @@ -383,7 +384,7 @@ def get_equal_contributions( Args: trajectories: The trajectories data. all_scores: The scores tensor. - + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ is_done = trajectories.when_is_done @@ -402,7 +403,7 @@ def get_tb_contributions( Args: trajectories: The trajectories data. all_scores: The scores tensor. - + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ max_len = trajectories.max_length @@ -427,7 +428,7 @@ def get_modified_db_contributions( Args: trajectories: The trajectories data. all_scores: The scores tensor. - + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores @@ -461,7 +462,7 @@ def get_geometric_within_contributions( Args: trajectories: The trajectories data. all_scores: The scores tensor. - + Returns: The contributions tensor of shape (max_len * (1 + max_len) / 2, n_trajectories). """ del all_scores diff --git a/src/gfn/gym/box.py b/src/gfn/gym/box.py index 5435f629..00b224e5 100644 --- a/src/gfn/gym/box.py +++ b/src/gfn/gym/box.py @@ -43,34 +43,28 @@ def __init__( exit_action=exit_action, ) - def make_random_states_tensor( - self, batch_shape: Tuple[int, ...] - ) -> torch.Tensor: + def make_random_states_tensor(self, batch_shape: Tuple[int, ...]) -> torch.Tensor: """Generates random states tensor of shape (*batch_shape, 2).""" return torch.rand(batch_shape + (2,), device=self.device) - def step( - self, states: States, actions: Actions - ) -> torch.Tensor: + def step(self, states: States, actions: Actions) -> torch.Tensor: """Step function for the Box environment. - + Args: states: States object representing the current states. actions: Actions object representing the actions to be taken. - + Returns the next states as tensor of shape (*batch_shape, 2). """ return states.tensor + actions.tensor - def backward_step( - self, states: States, actions: Actions - ) -> torch.Tensor: + def backward_step(self, states: States, actions: Actions) -> torch.Tensor: """Backward step function for the Box environment. Args: states: States object representing the current states. actions: Actions object representing the actions to be taken. - + Returns the previous states as tensor of shape (*batch_shape, 2). """ return states.tensor - actions.tensor @@ -78,7 +72,7 @@ def backward_step( @staticmethod def norm(x: torch.Tensor) -> torch.Tensor: """Computes the L2 norm of the input tensor along the last dimension. - + Args: x: Input tensor of shape (*batch_shape, 2). Returns: normalized tensor of shape `batch_shape`.""" @@ -126,10 +120,10 @@ def is_action_valid( def reward(self, final_states: States) -> torch.Tensor: """Reward is distance from the goal point. - + Args: final_states: States object representing the final states. - + Returns the reward tensor of shape `batch_shape`. """ R0, R1, R2 = (self.R0, self.R1, self.R2) @@ -137,7 +131,7 @@ def reward(self, final_states: States) -> torch.Tensor: reward = ( R0 + (0.25 < ax).prod(-1) * R1 + ((0.3 < ax) * (ax < 0.4)).prod(-1) * R2 ) - + assert reward.shape == final_states.batch_shape return reward diff --git a/src/gfn/gym/discrete_ebm.py b/src/gfn/gym/discrete_ebm.py index 8743db65..5823736d 100644 --- a/src/gfn/gym/discrete_ebm.py +++ b/src/gfn/gym/discrete_ebm.py @@ -16,13 +16,12 @@ class EnergyFunction(nn.Module, ABC): @abstractmethod def forward(self, states: torch.Tensor) -> torch.Tensor: """Forward pass of the energy function - + Args: states: tensor of states of shape (*batch_shape, *state_shape) - + Returns tensor of energies of shape (*batch_shape) """ - pass class IsingModel(EnergyFunction): @@ -46,7 +45,7 @@ def forward(self, states: torch.Tensor) -> torch.Tensor: Args: states: tensor of states of shape (*batch_shape, *state_shape) - + Returns tensor of energies of shape (*batch_shape) """ assert states.shape[-1] == self._state_shape @@ -133,10 +132,10 @@ def make_random_states_tensor(self, batch_shape: Tuple) -> torch.Tensor: def is_exit_actions(self, actions: torch.Tensor) -> torch.Tensor: """Determines if the actions are exit actions. - + Args: actions: tensor of actions of shape (*batch_shape, *action_shape) - + Returns tensor of booleans of shape (*batch_shape) """ return actions == self.n_actions - 1 @@ -147,7 +146,7 @@ def step(self, states: States, actions: Actions) -> torch.Tensor: Args: states: States object representing the current states. actions: Actions object representing the actions to be taken. - + Returns the next states as tensor of shape (*batch_shape, ndim). """ # First, we select that actions that replace a -1 with a 0. @@ -186,7 +185,7 @@ def reward(self, final_states: DiscreteStates) -> torch.Tensor: Args: final_states: DiscreteStates object representing the final states. - + Returns the reward as tensor of shape (*batch_shape). """ reward = torch.exp(self.log_reward(final_states)) @@ -195,10 +194,10 @@ def reward(self, final_states: DiscreteStates) -> torch.Tensor: def log_reward(self, final_states: DiscreteStates) -> torch.Tensor: """The energy weighted by alpha is our log reward. - + Args: final_states: DiscreteStates object representing the final states. - + Returns the log reward as tensor of shape (*batch_shape).""" raw_states = final_states.tensor canonical = 2 * raw_states - 1 @@ -209,10 +208,10 @@ def log_reward(self, final_states: DiscreteStates) -> torch.Tensor: def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: """The chosen encoding is the following: -1 -> 0, 0 -> 1, 1 -> 2, then we convert to base 3 - + Args: states: DiscreteStates object representing the states. - + Returns the states indices as tensor of shape (*batch_shape). """ states_raw = states.tensor @@ -223,10 +222,10 @@ def get_states_indices(self, states: DiscreteStates) -> torch.Tensor: def get_terminating_states_indices(self, states: DiscreteStates) -> torch.Tensor: """Returns the indices of the terminating states. - + Args: states: DiscreteStates object representing the states. - + Returns the indices of the terminating states as tensor of shape (*batch_shape). """ states_raw = states.tensor diff --git a/src/gfn/gym/helpers/box_utils.py b/src/gfn/gym/helpers/box_utils.py index 4f87f47a..295ff0eb 100644 --- a/src/gfn/gym/helpers/box_utils.py +++ b/src/gfn/gym/helpers/box_utils.py @@ -258,7 +258,7 @@ def __init__( alpha_theta: torch.Tensor, beta_theta: torch.Tensor, ): - """ "Initializes the distribution. + """Initializes the distribution. Args: delta: the radius of the quarter disk. diff --git a/src/gfn/gym/line.py b/src/gfn/gym/line.py index 90f0aeaf..cf85cc70 100644 --- a/src/gfn/gym/line.py +++ b/src/gfn/gym/line.py @@ -43,15 +43,13 @@ def __init__( exit_action=exit_action, ) # sf is -inf by default. - def step( - self, states: States, actions: Actions - ) -> torch.Tensor: + def step(self, states: States, actions: Actions) -> torch.Tensor: """Take a step in the environment. - + Args: states: The current states. actions: The actions to take. - + Returns the new states after taking the actions as a tensor of shape (*batch_shape, 2). """ states.tensor[..., 0] = states.tensor[..., 0] + actions.tensor.squeeze( @@ -61,15 +59,13 @@ def step( assert states.tensor.shape == states.batch_shape + (2,) return states.tensor - def backward_step( - self, states: States, actions: Actions - ) -> torch.Tensor: + def backward_step(self, states: States, actions: Actions) -> torch.Tensor: """Take a step in the environment in the backward direction. Args: states: The current states. actions: The actions to take. - + Returns the new states after taking the actions as a tensor of shape (*batch_shape, 2). """ states.tensor[..., 0] = states.tensor[..., 0] - actions.tensor.squeeze( @@ -90,10 +86,10 @@ def is_action_valid( def log_reward(self, final_states: States) -> torch.Tensor: """Log reward log of the environment. - + Args: final_states: The final states of the environment. - + Returns the log reward as a tensor of shape `batch_shape`. """ s = final_states.tensor[..., 0] diff --git a/src/gfn/modules.py b/src/gfn/modules.py index a0ddfed2..4351b462 100644 --- a/src/gfn/modules.py +++ b/src/gfn/modules.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional, Any +from typing import Any import torch import torch.nn as nn @@ -74,11 +74,12 @@ def __init__( def forward(self, input: States | torch.Tensor) -> torch.Tensor: """Forward pass of the module. - + Args: input: The input to the module, as states or a tensor. - - Returns the output of the module, as a tensor of shape (*batch_shape, output_dim).""" + + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). + """ if isinstance(input, States): input = self.preprocessor(input) @@ -126,7 +127,7 @@ def to_probability_distribution( states: The states to use. module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). **policy_kwargs: Keyword arguments to modify the distribution. - + Returns a distribution object. """ raise NotImplementedError @@ -257,11 +258,11 @@ def _forward_trunk( self, states: States, conditioning: torch.Tensor ) -> torch.Tensor: """Forward pass of the trunk of the module. - + Args: states: The input states. conditioning: The conditioning input. - + Returns the output of the trunk of the module, as a tensor of shape (*batch_shape, output_dim). """ state_out = self.module(self.preprocessor(states)) @@ -270,15 +271,13 @@ def _forward_trunk( return out - def forward( - self, states: States, conditioning: torch.tensor - ) -> torch.Tensor: + def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """Forward pass of the module. - + Args: states: The input states. conditioning: The conditioning input. - + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = self._forward_trunk(states, conditioning) @@ -308,15 +307,13 @@ def __init__( is_backward=is_backward, ) - def forward( - self, states: States, conditioning: torch.tensor - ) -> torch.Tensor: + def forward(self, states: States, conditioning: torch.tensor) -> torch.Tensor: """Forward pass of the module. - + Args: states: The input states. conditioning: The tensor for conditioning. - + Returns the output of the module, as a tensor of shape (*batch_shape, output_dim). """ out = self._forward_trunk(states, conditioning) @@ -342,7 +339,7 @@ def to_probability_distribution( states: The states to use. module_output: The output of the module as a tensor of shape (*batch_shape, output_dim). **policy_kwargs: Keyword arguments to modify the distribution. - + Returns a distribution object. """ raise NotImplementedError diff --git a/src/gfn/preprocessors.py b/src/gfn/preprocessors.py index 6754fabd..dfa3e2b1 100644 --- a/src/gfn/preprocessors.py +++ b/src/gfn/preprocessors.py @@ -2,6 +2,7 @@ from typing import Callable import torch + from gfn.states import States @@ -17,13 +18,12 @@ def __init__(self, output_dim: int) -> None: @abstractmethod def preprocess(self, states: States) -> torch.Tensor: """Transform the states to the input of the neural network. - + Args: states: The states to preprocess. - + Returns the preprocessed states as a tensor of shape (*batch_shape, output_dim). """ - pass def __call__(self, states: States) -> torch.Tensor: """Transform the states to the input of the neural network, calling the preprocess method.""" @@ -65,10 +65,10 @@ def __init__( def preprocess(self, states) -> torch.Tensor: """Preprocess the states by returning their unique indices. - + Args: states: The states to preprocess. - + Returns the unique indices of the states as a tensor of shape `batch_shape`. """ return self.get_states_indices(states).long().unsqueeze(-1) diff --git a/src/gfn/samplers.py b/src/gfn/samplers.py index 819620f0..eb224fbf 100644 --- a/src/gfn/samplers.py +++ b/src/gfn/samplers.py @@ -35,7 +35,7 @@ def sample_actions( save_estimator_outputs: bool = False, save_logprobs: bool = True, **policy_kwargs: Any, - ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None]: + ) -> Tuple[Actions, torch.Tensor | None, torch.Tensor | None,]: """Samples actions from the given states. Args: diff --git a/src/gfn/utils/distributions.py b/src/gfn/utils/distributions.py index ecf541b7..e0d6d3b0 100644 --- a/src/gfn/utils/distributions.py +++ b/src/gfn/utils/distributions.py @@ -19,10 +19,10 @@ class UnsqueezedCategorical(Categorical): def sample(self, sample_shape=torch.Size()) -> torch.Tensor: """Sample actions with an unsqueezed final dimension. - + Args: sample_shape: The shape of the sample. - + Returns the sampled actions as a tensor of shape (*sample_shape, *batch_shape, 1). """ out = super().sample(sample_shape).unsqueeze(-1) @@ -31,10 +31,10 @@ def sample(self, sample_shape=torch.Size()) -> torch.Tensor: def log_prob(self, sample: torch.Tensor) -> torch.Tensor: """Returns the log probabilities of an unsqueezed sample. - + Args: sample: The sample of for which to compute the log probabilities. - + Returns the log probabilities of the sample as a tensor of shape (*sample_shape, *batch_shape). """ assert sample.shape[-1] == 1 diff --git a/src/gfn/utils/training.py b/src/gfn/utils/training.py index 40c61f11..a15a5f5e 100644 --- a/src/gfn/utils/training.py +++ b/src/gfn/utils/training.py @@ -14,7 +14,7 @@ def get_terminating_state_dist_pmf(env: Env, states: States) -> torch.Tensor: Args: env: The environment. states: The states to compute the distribution of. - + Returns the empirical distribution of the terminating states as a tensor of shape (n_terminating_states,). """ states_indices = env.get_terminating_states_indices(states).cpu().numpy().tolist() diff --git a/testing/test_parametrizations_and_losses.py b/testing/test_parametrizations_and_losses.py index edaece25..20431d23 100644 --- a/testing/test_parametrizations_and_losses.py +++ b/testing/test_parametrizations_and_losses.py @@ -19,7 +19,7 @@ BoxPFMLP, ) from gfn.modules import DiscretePolicyEstimator, ScalarEstimator -from gfn.utils.modules import MLP, DiscreteUniform, Tabular +from gfn.utils.modules import DiscreteUniform, MLP, Tabular N = 10 # Number of trajectories from sample_trajectories (changes tests globally). diff --git a/tutorials/examples/train_conditional.py b/tutorials/examples/train_conditional.py index 057ccd71..a46a769e 100644 --- a/tutorials/examples/train_conditional.py +++ b/tutorials/examples/train_conditional.py @@ -1,15 +1,25 @@ #!/usr/bin/env python +from argparse import ArgumentParser + import torch -from tqdm import tqdm from torch.optim import Adam -from argparse import ArgumentParser +from tqdm import tqdm -from gfn.utils.common import set_seed -from gfn.gflownet import TBGFlowNet, DBGFlowNet, FMGFlowNet, SubTBGFlowNet, ModifiedDBGFlowNet +from gfn.gflownet import ( + DBGFlowNet, + FMGFlowNet, + ModifiedDBGFlowNet, + SubTBGFlowNet, + TBGFlowNet, +) from gfn.gym import HyperGrid -from gfn.modules import ConditionalDiscretePolicyEstimator, ScalarEstimator, ConditionalScalarEstimator +from gfn.modules import ( + ConditionalDiscretePolicyEstimator, + ConditionalScalarEstimator, + ScalarEstimator, +) from gfn.utils import NeuralNet - +from gfn.utils.common import set_seed DEFAULT_SEED = 4444 @@ -168,7 +178,6 @@ def build_subTB_gflownet(env): def train(env, gflownet, seed): - torch.manual_seed(0) exploration_rate = 0.5 lr = 0.0005 @@ -180,16 +189,20 @@ def train(env, gflownet, seed): # Policy parameters and logZ/logF get independent LRs (logF/Z typically higher). if type(gflownet) is TBGFlowNet: optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr) - optimizer.add_param_group({"params": gflownet.logz_parameters(), "lr": lr * 100}) + optimizer.add_param_group( + {"params": gflownet.logz_parameters(), "lr": lr * 100} + ) elif type(gflownet) is DBGFlowNet or type(gflownet) is SubTBGFlowNet: optimizer = Adam(gflownet.pf_pb_parameters(), lr=lr) - optimizer.add_param_group({"params": gflownet.logF_parameters(), "lr": lr * 100}) + optimizer.add_param_group( + {"params": gflownet.logF_parameters(), "lr": lr * 100} + ) elif type(gflownet) is FMGFlowNet or type(gflownet) is ModifiedDBGFlowNet: optimizer = Adam(gflownet.parameters(), lr=lr) else: print("What is this gflownet? {}".format(type(gflownet))) - n_iterations = int(10) # 1e4) + n_iterations = int(10) # 1e4) batch_size = int(1e4) print("+ Training Conditional {}!".format(type(gflownet))) @@ -244,7 +257,6 @@ def main(args): if __name__ == "__main__": - parser = ArgumentParser() parser.add_argument(