diff --git a/gflownet/policy/base.py b/gflownet/policy/base.py index eb98e9bd..5d26e0c1 100644 --- a/gflownet/policy/base.py +++ b/gflownet/policy/base.py @@ -1,13 +1,44 @@ -from abc import ABC, abstractmethod +""" +Base Policy class for GFlowNet policy models. +""" + +from typing import Tuple, Union import torch from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from gflownet.envs.base import GFlowNetEnv from gflownet.utils.common import set_device, set_float_precision class Policy: - def __init__(self, config, env, device, float_precision, base=None): + def __init__( + self, + config: Union[dict, DictConfig], + env: GFlowNetEnv, + device: Union[str, torch.device], + float_precision: [int, torch.dtype], + base=None, + ): + """ + Base Policy class for a :class:`GFlowNetAgent`. + + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + env : GFlowNetEnv + The environment used to train the :class:`GFlowNetAgent`, used to extract + needed properties. + device : str or torch.device + The device to be passed to torch tensors. + float_precision : int or torch.dtype + The floating point precision to be passed to torch tensors. + base: Policy (optional) + A base policy to be used as backbone for the backward policy. + """ + config = self._get_config(config) # Device and float precision self.device = set_device(device) self.float = set_float_precision(float_precision) @@ -18,24 +49,49 @@ def __init__(self, config, env, device, float_precision, base=None): self.output_dim = len(self.fixed_output) # Optional base model self.base = base + # Policy type, defaults to uniform + self.type = config.get("type", "uniform") + # Checkpoint, defaults to None + self.checkpoint = config.get("checkpoint", None) + # Instantiate the model + self.model, self.is_model = self.make_model() + + @staticmethod + def _get_config(config: Union[dict, DictConfig]) -> Union[dict, DictConfig]: + """ + Returns a configuration dictionary, even if the input is None. - self.parse_config(config) - self.instantiate() + Parameters + ---------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. It may be None, in + which an empty config is created and the defaults will be used. - def parse_config(self, config): - # If config is null, default to uniform + Returns + ------- + config : dict or DictConfig + The configuration dictionary to set up the policy model. + """ if config is None: config = OmegaConf.create() - self.type = config.get("type", "uniform") - self.checkpoint = config.get("checkpoint", None) + return config + + def make_model(self) -> Tuple[Union[torch.Tensor, torch.nn.Module], bool]: + """ + Instantiates the model of the policy. - def instantiate(self): + Returns + ------- + model : torch.tensor or torch.nn.Module + A tensor representing the output of the policy or a torch model. + is_model : bool + True if the policy is a model (for example, a neural network) and False if + it is a fixed tensor (for example to make a uniform distribution). + """ if self.type == "fixed": - self.model = self.fixed_distribution - self.is_model = False + return self.fixed_distribution, False elif self.type == "uniform": - self.model = self.uniform_distribution - self.is_model = False + return self.uniform_distribution, False else: raise "Policy model type not defined" diff --git a/gflownet/policy/cnn.py b/gflownet/policy/cnn.py index f8343d88..52693fc3 100644 --- a/gflownet/policy/cnn.py +++ b/gflownet/policy/cnn.py @@ -6,19 +6,33 @@ class CNNPolicy(Policy): - def __init__(self, config, env, device, float_precision, base=None): + def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # CNN features: number of layers, number of channels, kernel sizes, strides + self.n_layers = config.get("n_layers", 3) + self.channels = config.get("channels", [16] * self.n_layers) + self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) + self.strides = config.get("strides", [(1, 1)] * self.n_layers) + # Environment + # TODO: rethink whether storing the whole environment is needed self.env = env - super().__init__( - config=config, - env=env, - device=device, - float_precision=float_precision, - base=base, - ) + # Base init + super().__init__(**kwargs) - def make_cnn(self): + def make_model(self): """ - Defines an CNN with no top layer activation + Instantiates a CNN with no top layer activation. + + Returns + ------- + model : torch.nn.Module + A torch model containing the CNN. + is_model : bool + True because a CNN is a model. """ if self.shared_weights and self.base is not None: layers = list(self.base.model.children())[:-1] @@ -27,14 +41,15 @@ def make_cnn(self): ) model = nn.Sequential(*layers, last_layer).to(self.device) - return model + return model, True current_channels = 1 conv_module = nn.Sequential() if len(self.kernel_sizes) != self.n_layers: raise ValueError( - f"Inconsistent dimensions kernel_sizes != n_layers, {len(self.kernel_sizes)} != {self.n_layers}" + f"Inconsistent dimensions kernel_sizes != n_layers, " + "{len(self.kernel_sizes)} != {self.n_layers}" ) for i in range(self.n_layers): @@ -59,33 +74,19 @@ def make_cnn(self): in_channels = conv_module(dummy_input).numel() if in_channels >= 500_000: # TODO: this could better be handled raise RuntimeWarning( - "Input channels for the dense layer are too big, this will increase number of parameters" + "Input channels for the dense layer are too big, this will " + "increase number of parameters" ) except RuntimeError as e: raise RuntimeError( - "Failed during convolution operation. Ensure that the kernel sizes and strides are appropriate for the input dimensions." + "Failed during convolution operation. Ensure that the kernel sizes " + "and strides are appropriate for the input dimensions." ) from e model = nn.Sequential( conv_module, nn.Flatten(), nn.Linear(in_channels, self.output_dim) ) - return model.to(self.device) - - def parse_config(self, config): - super().parse_config(config) - if config is None: - config = OmegaConf.create() - self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.reload_ckpt = config.get("reload_ckpt", False) - self.n_layers = config.get("n_layers", 3) - self.channels = config.get("channels", [16] * self.n_layers) - self.kernel_sizes = config.get("kernel_sizes", [(3, 3)] * self.n_layers) - self.strides = config.get("strides", [(1, 1)] * self.n_layers) - - def instantiate(self): - self.model = self.make_cnn() - self.is_model = True + return model.to(self.device), True def __call__(self, states): states = states.unsqueeze(1) # (batch_size, channels, height, width) diff --git a/gflownet/policy/mlp.py b/gflownet/policy/mlp.py index b90f6e52..8f4fbc80 100644 --- a/gflownet/policy/mlp.py +++ b/gflownet/policy/mlp.py @@ -5,27 +5,40 @@ class MLPPolicy(Policy): - def __init__(self, config, env, device, float_precision, base=None): - super().__init__( - config=config, - env=env, - device=device, - float_precision=float_precision, - base=base, - ) + def __init__(self, **kwargs): + config = self._get_config(kwargs["config"]) + # Shared weights, defaults to False + self.shared_weights = config.get("shared_weights", False) + # Reload checkpoint, defaults to False + self.reload_ckpt = config.get("reload_ckpt", False) + # MLP features: number of layers, number of hidden units, tail, etc. + self.n_layers = config.get("n_layers", 2) + self.n_hid = config.get("n_hid", 128) + self.tail = config.get("tail", []) + # Base init + super().__init__(**kwargs) - def make_mlp(self, activation): + def make_model(self, activation: nn.Module = nn.LeakyReLU()): """ - Defines an MLP with no top layer activation - If share_weight == True, - baseModel (the model with which weights are to be shared) must be provided - Args - ---- - layers_dim : list - Dimensionality of each layer - activation : Activation - Activation function + Instantiates an MLP with no top layer activation as the policy model. + + If self.shared_weights is True, the base model with which weights are to be + shared must be provided. + + Parameters + ---------- + activation : nn.Module + Activation function of the MLP layers + + Returns + ------- + model : torch.tensor or torch.nn.Module + A torch model containing the MLP. + is_model : bool + True because an MLP is a model. """ + activation.to(self.device) + if self.shared_weights == True and self.base is not None: mlp = nn.Sequential( self.base.model[:-1], @@ -33,7 +46,7 @@ def make_mlp(self, activation): self.base.model[-1].in_features, self.base.model[-1].out_features ), ) - return mlp + return mlp, True elif self.shared_weights == False: layers_dim = ( [self.state_dim] + [self.n_hid] * self.n_layers + [(self.output_dim)] @@ -53,26 +66,11 @@ def make_mlp(self, activation): + self.tail ) ) - return mlp + return mlp, True else: raise ValueError( "Base Model must be provided when shared_weights is set to True" ) - def parse_config(self, config): - super().parse_config(config) - if config is None: - config = OmegaConf.create() - self.checkpoint = config.get("checkpoint", None) - self.shared_weights = config.get("shared_weights", False) - self.n_hid = config.get("n_hid", 128) - self.n_layers = config.get("n_layers", 2) - self.tail = config.get("tail", []) - self.reload_ckpt = config.get("reload_ckpt", False) - - def instantiate(self): - self.model = self.make_mlp(nn.LeakyReLU()).to(self.device) - self.is_model = True - def __call__(self, states): return self.model(states)